diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py index 5cd7863005b1aaa9e3f3238c18b25d78a729e5ed..814b2c32d253a9dec9015bac5e05569361e92eba 100644 --- a/gsrp_smart_util.py +++ b/gsrp_smart_util.py @@ -107,6 +107,12 @@ def mask_val(mem, val): return mem[0][mask], mem[1][mask] +def constrained_argmax(mem, tij_ind, t_max): + min_t, max_t = tij_ind.min(), tij_ind.max() + mask = ((max_t - t_max <= mem[1]) & (mem[1] <= min_t + t_max)).all(-1) + return mem[1][mask][np.argmax(mem[0][mask])] + + def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list): memory = dict() val = 0 @@ -124,8 +130,9 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list): memory[(i, j)] = mul(mask_val(memory[(i-1, op.left)], val), mask_val(memory[(0, op.right)], val), cc, t_max, tree[i - 1][op.left], tree[0][op.right][0], n_ind) # find potential maximum + tij[:] = 0 for j in range(len(step)): - tij[tree[i][j]] = memory[(i, j)][1][np.argmax(memory[(i, j)][0])] + tij[tree[i][j]] = constrained_argmax(memory[(i, j)], tij[:n_ind], t_max) dep_tdoa(tij, n_ind, n_tot) val = cc[tij, np.arange(n_tot)].prod() # print('tdoa:', tij, 'val:', val, 'mem size:', (lambda x: f'{x} ({100 * x / (2 * t_max + 1) ** i}%)')(sum(len(o[0]) for o in memory.values()))) diff --git a/gsrp_tdoa_hyperres.py b/gsrp_tdoa_hyperres.py index 419a75474507de3508df75774fb014d2b1f92a80..6f1e07c6a0402b4d1ed139c96d81308b8bc6686a 100755 --- a/gsrp_tdoa_hyperres.py +++ b/gsrp_tdoa_hyperres.py @@ -120,11 +120,12 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True): cc /= maxs val, tij = smart_gsrp(cc.T, num_channels - 1, num_channel_pairs, cc_size // 2, tree, program, clean_list) tdoas[i, 0], tdoas[i, 1:] = np.log10(val * maxs.prod()), tij[:(num_channels - 1)] + cc *= maxs else: raise ValueError(f'Unknown mode {mode}') if hyper: - tdoas[i, 0], tdoas2[i, 1:] = _hyperres(tdoas[i, 1:], cc) + tdoas2[i, 0], tdoas2[i, 1:] = _hyperres(tdoas[i, 1:], cc) if hyper: return np.hstack((np.expand_dims(pos, -1), tdoas)), np.hstack((np.expand_dims(pos, -1), tdoas2))