diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py
index eb0a8613cf4aea4fe03424f65c3e6e1a9c22e08e..a8b470e1d4dc078d51358465dbbb29c8a47d178e 100644
--- a/gsrp_smart_util.py
+++ b/gsrp_smart_util.py
@@ -57,7 +57,7 @@ def op_tree(tree):
         for group in tree[i]:
             if group[:-1] in tree[i-1]:
                 j1 = tree[i-1].index(group[:-1])
-                j2 =  tree[0].index(group[-1:])
+                j2 = tree[0].index(group[-1:])
                 program[i].append(Operation('mul', j1, j2, i))
                 program[i-1][j1].lifetime = i
                 program[0][j2].lifetime = i
@@ -85,11 +85,29 @@ def dep_tdoa(tij, nind, ntot):
             i0 = j0 + 1
 
 
-def num_ind(i, j, n_ind):
+def num_ind_dep(i, j, n_ind):  # row of cc for the pair (i, j); callers visible here pass the larger index first
     return j*(n_ind-1) + i-1 - (j*(j+1))//2 + n_ind
 
 
-def mul(mem1, mem2, cc, t_max, id1, id2, n_ind, mem_limit=np.infty):
+def inv_num_ind(k, n_ind):  # flat pair index k -> (i, j) with i > j; group j holds n_ind - j consecutive indices
+    row = int((2*n_ind + 1 - ((2*n_ind + 1)**2 - 8*k)**0.5) / 2)
+    return k - row*n_ind + row*(row + 1)//2 + 1, row
+
+
+def add_all(memory, tree, cc, t_max, n_ind, n_tot):  # exhaustive search over every combination of memory entries
+    entries = list(memory.values())
+    grids = np.meshgrid(*(np.arange(len(scores)) for scores, _ in entries), copy=False)
+    combos = np.empty((n_tot, grids[0].size), dtype=entries[0][1].dtype)
+    for (level, pos), (_, tdoa), grid in zip(memory, entries, grids):  # memory must hold every independent TDOA
+        combos[tree[level][pos]] = tdoa[:, grid].reshape(-1, grid.size)
+    dep_tdoa(combos, n_ind, n_tot)  # presumably fills the dependent rows in place (return value unused) - confirm
+    combos = combos[:, np.all(np.abs(combos) <= t_max, axis=0)]  # keep only columns where every |TDOA| <= t_max
+    totals = cc[np.arange(n_tot)[:, np.newaxis], combos].sum(axis=0)
+    best = np.argmax(totals)
+    return totals[best], combos[:, best]
+
+
+def add(mem1, mem2, cc, t_max, id1, id2, n_ind, mem_limit=np.infty):
     # assume len(id2) == 1
     if (2 + len(id1)) * mem1[0].size * mem2[0].size * mem1[0].itemsize > mem_limit:
         return
@@ -98,13 +116,13 @@ def mul(mem1, mem2, cc, t_max, id1, id2, n_ind, mem_limit=np.infty):
     out_tij = np.empty((len(id1) + 1, len(idx1)), mem1[1].dtype)
     np.take(mem1[1], idx1, axis=1, out=out_tij[:-1])
     np.take(mem2[1], idx2, axis=1, out=out_tij[-1:])
-    mask = reduce_all(np.abs(out_tij[:-1] - out_tij[-1:]) <= t_max) # Faster than numpy for large array
+    mask = reduce_all(np.abs(out_tij[:-1] - out_tij[-1:]) <= t_max)  # Faster than numpy for large array
     out_tij = np.compress(mask, out_tij, axis=1)
     idx1, idx2 = idx1[mask], idx2[mask]
     out_val = mem1[0][idx1] + mem2[0][idx2]
     tij_dep = out_tij[:-1] - out_tij[-1:]
     tij_dep *= np.array([1 if i > id2 else -1 for i in id1])[:, np.newaxis]
-    ch_dep = np.array([num_ind(i, id2, n_ind) if i > id2 else num_ind(id2, i, n_ind) for i in id1])
+    ch_dep = np.array([num_ind_dep(i, id2, n_ind) if i > id2 else num_ind_dep(id2, i, n_ind) for i in id1])
     out_val += cc[ch_dep[:, np.newaxis], tij_dep].sum(0)
     return out_val, out_tij
 
@@ -132,12 +150,19 @@ def constrained_argmax(mem, cc, tij_ind, curr_tij, used_tij, t_max, n_ind):
     for u in used_tij:
         for i, c in enumerate(curr_tij):
             if u < c:
-                mem_val += cc[num_ind(c, u, n_ind), mem_tij[i] - tij_ind[u]]
+                mem_val += cc[num_ind_dep(c, u, n_ind), mem_tij[i] - tij_ind[u]]
             else:
-                mem_val += cc[num_ind(u, c, n_ind), tij_ind[u] - mem_tij[i]]
+                mem_val += cc[num_ind_dep(u, c, n_ind), tij_ind[u] - mem_tij[i]]
     return mem_tij[:, np.argmax(mem_val)]
 
 
+def truncated_argmax(cc, t_max):  # per-row argmax of cc over lags in [-t_max, t_max], returned as signed lags
+    argmax_plus = np.argmax(cc[:, :t_max + 1], axis=1)  # lags 0..t_max inclusive, like the |tij| <= t_max checks
+    argmax_minus = np.argmax(cc[:, -t_max:], axis=1) - t_max  # lags -t_max..-1 are the last t_max columns
+    x = np.arange(len(argmax_minus))  # NOTE: requires t_max >= 1, since cc[:, -0:] would select every column
+    return np.where(cc[x, argmax_minus] < cc[x, argmax_plus], argmax_plus, argmax_minus)
+
+
 def _get_mem_size(memory):
     return sum(len(o[0].T) for o in memory.values())
 
@@ -150,45 +175,76 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False
     memory = dict()
     with np.errstate(divide='ignore'):
         cc = np.log10(cc)
-    val = cc[:, 0].sum()
+    out_val = cc[:, 0].sum()
     tij = np.zeros(n_tot, int)
+    out_tij = tij.copy()
+    tij[:n_ind] = truncated_argmax(cc[:n_ind], t_max)
+    dep_tdoa(tij, n_ind, n_tot)
+    if np.all(np.abs(tij) <= t_max):
+        val = cc[np.arange(n_tot), tij].sum()
+        if val > out_val:
+            out_tij = tij.copy()
+            out_val = val
+    if out_val < -1:  # out_val < -20dB, trying with other channel as independent base
+        _sort = (np.argsort(np.asarray(np.triu_indices(n_ind + 1, 1)).ravel()) % tij.size)[n_ind:]
+        # _sort depends only on n_ind and could be precomputed, but computing it takes only 40µs for n_ind = 4
+        ch = np.argmin(cc[np.arange(n_tot), tij][_sort].reshape(-1, n_ind).sum(1))
+        other_ch = np.asarray([i for i in range(n_ind) if i != ch])
+        dep = [num_ind_dep(i, ch, n_ind) if i > ch else num_ind_dep(ch, i, n_ind) for i in other_ch]
+        t_dep = truncated_argmax(cc[dep], t_max)
+        tij[other_ch] = (-1)**(other_ch < ch) * t_dep + tij[ch]
+        dep_tdoa(tij, n_ind, n_tot)
+        if np.all(np.abs(tij) <= t_max):
+            val = cc[np.arange(n_tot), tij].sum()
+            if val > out_val:
+                out_tij = tij.copy()
+                out_val = val  # if this point is reach, it might be useful to swap(0, ch)
     for i, step in enumerate(program):
         # increase dimensions
+        for k, v in memory.items():
+            curr_tij = frozenset(tree[k[0]][k[1]])
+            memory[k] = mask_val(v, out_val - sum(m for tij, m in max_val.items() if curr_tij.isdisjoint(tij)))
+        if i and np.prod([v[0].size for v in memory.values()], dtype=np.float64) <= 1024:  # float to prevent overflow
+            return *add_all(memory, tree, cc, t_max, n_ind, n_tot), 1
         for j, op in enumerate(step):
             if op == 'mem':
                 if i == 0:
-                    memory[(0, j)] = mask_val((np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])),
-                                               np.arange(-t_max, t_max + 1, dtype=np.int32)[np.newaxis]), val)
+                    memory[0, j] = mask_val((np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])),
+                                             np.arange(-t_max, t_max + 1, dtype=np.int32)[np.newaxis]), out_val)
                 else:
-                    memory[(i, j)] = mask_val(memory[(op.left, op.right)], val)
+                    memory[i, j] = memory[op.left, op.right]
             else:  # op == 'mul'
-                memory[(i, j)] = mul(mask_val(memory[(i-1, op.left)], val), mask_val(memory[(0, op.right)], val),
-                                     cc, t_max, tree[i - 1][op.left], tree[0][op.right][0], n_ind,
-                                     mem_limit=mem_limit - _get_mem_usage(memory))
-                if memory[(i, j)] is None:  # means that the memory limit has been reach
-                    return val, tij, 0
+                memory[i, j] = add(memory[i - 1, op.left], memory[0, op.right], cc, t_max, tree[i - 1][op.left],
+                                   tree[0][op.right][0], n_ind, mem_limit=mem_limit - _get_mem_usage(memory))
+                if memory[i, j] is None:  # means that the memory limit has been reach
+                    return out_val, out_tij, 0
         # find potential maximum
         tij[:] = 0
         done_tij = set()
+        max_val = {frozenset(tree[i][j]): memory[i, j][0].max() for j in range(len(step))}
         try:
             for j in range(len(step)):
                 curr_tij = tree[i][j]
-                tij[curr_tij] = constrained_argmax(memory[(i, j)], cc, tij[:n_ind], curr_tij, done_tij, t_max, n_ind)
+                tij[curr_tij] = constrained_argmax(memory[i, j], cc, tij[:n_ind], curr_tij, done_tij, t_max, n_ind)
                 done_tij.update(curr_tij)
             dep_tdoa(tij, n_ind, n_tot)
-            val = max(cc[np.arange(n_tot), tij].sum(), val)
+            val = cc[np.arange(n_tot), tij].sum()
+            if val > out_val:
+                out_tij = tij.copy()
+                out_val = val
         except ValueError:  # search of potential maxima ended outside of possible values
-            tij_min, tij_max = memory[(i, j)][1].min(), memory[(i, j)][1].max()
+            tij_min, tij_max = memory[i, j][1].min(), memory[i, j][1].max()
             for k in range(j):
-                memory[(i, k)] = mask_lim(memory[(i, k)], tij_min, tij_max, t_max)
+                memory[i, k] = mask_lim(memory[i, k], tij_min, tij_max, t_max)
 
         if verbose:
             mem_size = _get_mem_size(memory)
-            tqdm.write(f'TDOA: {tij}, val: {20*val:7.3f}dB, mem size: {mem_size} items, {_get_mem_usage(memory):3.2e} octets,'
-                       f' {100 * mem_size / (n_ind // (i + 1)) / (2 * t_max + 1) ** (i + 1):.4}%')
+            tqdm.write(f'TDOA: {out_tij}, val: {20*out_val:7.3f}dB, mem size: {mem_size} items, '
+                       f'{_get_mem_usage(memory):3.2e} octets, '
+                       f'{100 * mem_size / (n_ind // (i + 1)) / (2 * t_max + 1) ** (i + 1):.4}%')
 
         # Mem clean up
         for p in clean_list[i]:
             del memory[p]
 
-    return val, tij, 1
+    return out_val, out_tij, 1