diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py
index 814b2c32d253a9dec9015bac5e05569361e92eba..3c82fd6e5dae620dc890af9b845aba2bbb322ab5 100644
--- a/gsrp_smart_util.py
+++ b/gsrp_smart_util.py
@@ -88,29 +88,29 @@ def num_ind(i, j, nind):
 
 def mul(mem1, mem2, cc, t_max, id1, id2, n_ind):
     # assume len(id2) == 1
-    idx1, idx2 = np.meshgrid(np.arange(len(mem1[0])), np.arange(len(mem2[1])))
+    idx1, idx2 = np.meshgrid(np.arange(len(mem1[0])), np.arange(len(mem2[0])))
     idx1, idx2 = idx1.ravel(), idx2.ravel()
-    out_tij = np.concatenate((mem1[1][idx1], mem2[1][idx2]), axis=-1)
-    mask = (np.abs(out_tij[:, :-1] - out_tij[:, -1:]) <= t_max).all(-1)
-    out_tij = out_tij[mask]
+    out_tij = np.concatenate((mem1[1][:, idx1], mem2[1][:, idx2]), axis=0)
+    mask = (np.abs(out_tij[:-1] - out_tij[-1:]) <= t_max).all(0)
+    out_tij = out_tij[:, mask]
     idx1, idx2 = idx1[mask], idx2[mask]
     out_val = mem1[0][idx1] * mem2[0][idx2]
-    tij_dep = out_tij[:, :-1] - out_tij[:, -1:]
-    tij_dep *= np.array([1 if i > id2 else -1 for i in id1])
+    tij_dep = out_tij[:-1] - out_tij[-1:]
+    tij_dep *= np.array([1 if i > id2 else -1 for i in id1])[:, np.newaxis]
     ch_dep = np.array([num_ind(i, id2, n_ind) if i > id2 else num_ind(id2, i, n_ind) for i in id1]) + n_ind
-    out_val *= cc[tij_dep, ch_dep].prod(-1)
+    out_val *= cc[ch_dep[:, np.newaxis], tij_dep].prod(0)
     return out_val, out_tij
 
 
 def mask_val(mem, val):
     mask = mem[0] >= val
-    return mem[0][mask], mem[1][mask]
+    return mem[0][mask], mem[1][:, mask]
 
 
 def constrained_argmax(mem, tij_ind, t_max):
     min_t, max_t = tij_ind.min(), tij_ind.max()
-    mask = ((max_t - t_max <= mem[1]) & (mem[1] <= min_t + t_max)).all(-1)
-    return mem[1][mask][np.argmax(mem[0][mask])]
+    mask = ((max_t - t_max <= mem[1]) & (mem[1] <= min_t + t_max)).all(0)
+    return mem[1][:, mask][:, np.argmax(mem[0][mask])]
 
 
 def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list):
@@ -122,8 +122,8 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list):
         for j, op in enumerate(step):
             if op == 'mem':
                 if i == 0:
-                    memory[(0, j)] = (np.concatenate((cc[-t_max:, op.right], cc[:t_max + 1, op.right])),
-                                      np.arange(-t_max, t_max + 1).reshape(-1, 1))
+                    memory[(0, j)] = (np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])),
+                                      np.arange(-t_max, t_max + 1)[np.newaxis])
                 else:
                     memory[(i, j)] = mask_val(memory[(op.left, op.right)], val)
             else:  # op == 'mul'
@@ -134,8 +134,8 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list):
         for j in range(len(step)):
             tij[tree[i][j]] = constrained_argmax(memory[(i, j)], tij[:n_ind], t_max)
         dep_tdoa(tij, n_ind, n_tot)
-        val = cc[tij, np.arange(n_tot)].prod()
-        # print('tdoa:', tij, 'val:', val, 'mem size:', (lambda x: f'{x} ({100 * x / (2 * t_max + 1) ** i}%)')(sum(len(o[0]) for o in memory.values())))
+        val = cc[np.arange(n_tot), tij].prod()
+        # print('tdoa:', tij, 'val:', val, 'mem size:', (lambda x: f'{x} ({100 * x / (2 * t_max + 1) ** (i+1)}%)')(sum(len(o[0].T) for o in memory.values())))
 
         # Mem clean up
         for p in clean_list[i]:
diff --git a/gsrp_tdoa_hyperres.py b/gsrp_tdoa_hyperres.py
index 6f1e07c6a0402b4d1ed139c96d81308b8bc6686a..803d30ffefc0e6f2989aedeccc8973a07ff5a378 100755
--- a/gsrp_tdoa_hyperres.py
+++ b/gsrp_tdoa_hyperres.py
@@ -118,7 +118,7 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True):
         elif mode == 'smart':
             maxs = cc.max(1, keepdims=True)
             cc /= maxs
-            val, tij = smart_gsrp(cc.T, num_channels - 1, num_channel_pairs, cc_size // 2, tree, program, clean_list)
+            val, tij = smart_gsrp(cc, num_channels - 1, num_channel_pairs, cc_size // 2, tree, program, clean_list)
             tdoas[i, 0], tdoas[i, 1:] = np.log10(val * maxs.prod()), tij[:(num_channels - 1)]
             cc *= maxs
         else:
@@ -187,7 +187,7 @@ def main(args):
         result1 = results
     else:
         result1, result2 =results
-    # compute additional non-independent TDOAs
+    # compute additional non-independent TDOAs  # TODO: don't recompute them
     additional1 = [(b - a) for a, b in itertools.combinations(result1.T[2:], 2)]
     result1 = np.hstack((result1,) + tuple(a[:, np.newaxis] for a in additional1))
     if not args.no_hyperres: