From b17cfe1d64b4b5469703388bf01468e63b98563a Mon Sep 17 00:00:00 2001
From: ferrari <maxence.ferrari@gmail.com>
Date: Wed, 9 Feb 2022 14:00:42 +0100
Subject: [PATCH] Fix smart method search space

---
 gsrp_smart_util.py    | 9 ++++++++-
 gsrp_tdoa_hyperres.py | 3 ++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py
index 5cd7863..814b2c3 100644
--- a/gsrp_smart_util.py
+++ b/gsrp_smart_util.py
@@ -107,6 +107,12 @@ def mask_val(mem, val):
     return mem[0][mask], mem[1][mask]
 
 
+def constrained_argmax(mem, tij_ind, t_max):
+    min_t, max_t = tij_ind.min(), tij_ind.max()
+    mask = ((max_t - t_max <= mem[1]) & (mem[1] <= min_t + t_max)).all(-1)
+    return mem[1][mask][np.argmax(mem[0][mask])]
+
+
 def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list):
     memory = dict()
     val = 0
@@ -124,8 +130,9 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list):
                 memory[(i, j)] = mul(mask_val(memory[(i-1, op.left)], val), mask_val(memory[(0, op.right)], val),
                                      cc, t_max, tree[i - 1][op.left], tree[0][op.right][0], n_ind)
         # find potential maximum
+        tij[:] = 0
         for j in range(len(step)):
-            tij[tree[i][j]] = memory[(i, j)][1][np.argmax(memory[(i, j)][0])]
+            tij[tree[i][j]] = constrained_argmax(memory[(i, j)], tij[:n_ind], t_max)
         dep_tdoa(tij, n_ind, n_tot)
         val = cc[tij, np.arange(n_tot)].prod()
         # print('tdoa:', tij, 'val:', val, 'mem size:', (lambda x: f'{x} ({100 * x / (2 * t_max + 1) ** i}%)')(sum(len(o[0]) for o in memory.values())))
diff --git a/gsrp_tdoa_hyperres.py b/gsrp_tdoa_hyperres.py
index 419a754..6f1e07c 100755
--- a/gsrp_tdoa_hyperres.py
+++ b/gsrp_tdoa_hyperres.py
@@ -120,11 +120,12 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True):
             cc /= maxs
             val, tij = smart_gsrp(cc.T, num_channels - 1, num_channel_pairs, cc_size // 2, tree, program, clean_list)
             tdoas[i, 0], tdoas[i, 1:] = np.log10(val * maxs.prod()), tij[:(num_channels - 1)]
+            cc *= maxs
         else:
             raise ValueError(f'Unknown mode {mode}')
 
         if hyper:
-            tdoas[i, 0], tdoas2[i, 1:] = _hyperres(tdoas[i, 1:], cc)
+            tdoas2[i, 0], tdoas2[i, 1:] = _hyperres(tdoas[i, 1:], cc)
 
     if hyper:
         return np.hstack((np.expand_dims(pos, -1), tdoas)), np.hstack((np.expand_dims(pos, -1), tdoas2))
-- 
GitLab