diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py index 3c82fd6e5dae620dc890af9b845aba2bbb322ab5..d7ac8ab9b2ad803a71a6303a4d786b63889ad6dd 100644 --- a/gsrp_smart_util.py +++ b/gsrp_smart_util.py @@ -82,8 +82,8 @@ def dep_tdoa(tij, nind, ntot): i0 = j0 + 1 -def num_ind(i, j, nind): - return j*(nind-1) + i-1 - (j*(j+1))//2 +def num_ind(i, j, n_ind): + return j*(n_ind-1) + i-1 - (j*(j+1))//2 + n_ind def mul(mem1, mem2, cc, t_max, id1, id2, n_ind): @@ -97,7 +97,7 @@ def mul(mem1, mem2, cc, t_max, id1, id2, n_ind): out_val = mem1[0][idx1] * mem2[0][idx2] tij_dep = out_tij[:-1] - out_tij[-1:] tij_dep *= np.array([1 if i > id2 else -1 for i in id1])[:, np.newaxis] - ch_dep = np.array([num_ind(i, id2, n_ind) if i > id2 else num_ind(id2, i, n_ind) for i in id1]) + n_ind + ch_dep = np.array([num_ind(i, id2, n_ind) if i > id2 else num_ind(id2, i, n_ind) for i in id1]) out_val *= cc[ch_dep[:, np.newaxis], tij_dep].prod(0) return out_val, out_tij @@ -107,10 +107,21 @@ def mask_val(mem, val): return mem[0][mask], mem[1][:, mask] -def constrained_argmax(mem, tij_ind, t_max): +def mask_lim(mem, tij_min, tij_max, t_max): + mask = ((tij_max - t_max <= mem[1]) & (mem[1] <= tij_min + t_max)).all(0) + return mem[0][mask], mem[1][:, mask] + + +def constrained_argmax(mem, cc, tij_ind, curr_tij, used_tij, t_max, n_ind): min_t, max_t = tij_ind.min(), tij_ind.max() - mask = ((max_t - t_max <= mem[1]) & (mem[1] <= min_t + t_max)).all(0) - return mem[1][:, mask][:, np.argmax(mem[0][mask])] + mem_val, mem_tij = mask_lim(mem, min_t, max_t, t_max) + for u in used_tij: + for i, c in enumerate(curr_tij): + if u < c: + mem_val *= cc[num_ind(c, u, n_ind), mem_tij[i] - tij_ind[u]] + else: + mem_val *= cc[num_ind(u, c, n_ind), tij_ind[u] - mem_tij[i]] + return mem_tij[:, np.argmax(mem_val)] def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list): @@ -131,11 +142,20 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list): cc, t_max, tree[i - 1][op.left], tree[0][op.right][0], n_ind) # find potential maximum tij[:] = 0 - for j in range(len(step)): - tij[tree[i][j]] = constrained_argmax(memory[(i, j)], tij[:n_ind], t_max) - dep_tdoa(tij, n_ind, n_tot) - val = cc[np.arange(n_tot), tij].prod() - # print('tdoa:', tij, 'val:', val, 'mem size:', (lambda x: f'{x} ({100 * x / (2 * t_max + 1) ** (i+1)}%)')(sum(len(o[0].T) for o in memory.values()))) + done_tij = set() + try: + for j in range(len(step)): + curr_tij = tree[i][j] + tij[curr_tij] = constrained_argmax(memory[(i, j)], cc, tij[:n_ind], curr_tij, done_tij, t_max, n_ind) + done_tij.update(curr_tij) + dep_tdoa(tij, n_ind, n_tot) + val = max(cc[np.arange(n_tot), tij].prod(), val) + except ValueError: # search for potential maxima outside of possible values + tij_min, tij_max = memory[(i, j)][1].min(), memory[(i, j)][1].max() + for k in range(j): + memory[(i, k)] = mask_lim(memory[(i, k)], tij_min, tij_max, t_max) + + # print('tdoa:', tij, 'val:', val, 'mem size:', (lambda x: f'{x} ({100 * x / (n_ind//(i+1)) / (2 * t_max + 1) ** (i+1)}%)')(sum(len(o[0].T) for o in memory.values()))) # Mem clean up for p in clean_list[i]: diff --git a/gsrp_tdoa_hyperres.py b/gsrp_tdoa_hyperres.py index 313d6a8f75a9b57a308b76e20662a61816a0c9ed..334868d830f540ed06fc7e2c0d4fa16fde2ec923 100755 --- a/gsrp_tdoa_hyperres.py +++ b/gsrp_tdoa_hyperres.py @@ -38,7 +38,7 @@ def intlist(s): def slicer(down, up, ndim, n): index = np.mgrid[ndim * [slice(0, n)]] - bounds = np.linspace(down, up, n + 1).astype(np.int) + bounds = np.linspace(down, up, n + 1).astype(int) slices = np.asarray([slice(a, b) for a, b in zip(bounds[:-1], bounds[1:])]) return slices[index].reshape(ndim, -1).T @@ -195,7 +195,7 @@ def main(args): else: np.savetxt(args.outfile, result1, delimiter=',') if not args.no_hyperres: - np.savetxt((lambda x1, x2, x3: x1 + '_2' + x2 + x3)(*args.outfile.rpartition('.', 1)), + np.savetxt((lambda x1, x2, x3: x1 + '_2' + x2 + x3)(*args.outfile.rpartition('.')), result2, delimiter=',') print("Done.") return 0 @@ -276,6 +276,6 @@ if __name__ == "__main__": sys.exit(main(args)) - except KeyError as e: + except Exception as e: print(e) sys.exit(2)