diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py
index a8b470e1d4dc078d51358465dbbb29c8a47d178e..25e6829270ebc6f9827ec1de8b0d503c3b8b638e 100644
--- a/gsrp_smart_util.py
+++ b/gsrp_smart_util.py
@@ -157,7 +157,7 @@ def constrained_argmax(mem, cc, tij_ind, curr_tij, used_tij, t_max, n_ind):
 
 
 def truncated_argmax(cc, t_max):
-    argmax_plus = np.argmax(cc[:, :t_max], axis=1)
+    argmax_plus = np.argmax(cc[:, :t_max+1], axis=1)
     argmax_minus = np.argmax(cc[:, -t_max:], axis=1) - t_max
     x = np.arange(len(argmax_minus))
     return np.where(cc[x, argmax_minus] < cc[x, argmax_plus], argmax_plus, argmax_minus)
@@ -174,11 +174,12 @@ def _get_mem_usage(memory):
 def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False, mem_limit=np.infty):
     memory = dict()
     with np.errstate(divide='ignore'):
-        cc = np.log10(cc)
+        cc = np.concatenate((np.log10(cc[:, :t_max+1]), np.log10(cc[:, -t_max:])), axis=1)
     out_val = cc[:, 0].sum()
     tij = np.zeros(n_tot, int)
     out_tij = tij.copy()
-    tij[:n_ind] = truncated_argmax(cc[:n_ind], t_max)
+    tij[:n_ind] = np.argmax(cc[:n_ind], 1)
+    tij[:n_ind][tij[:n_ind] > t_max] -= cc.shape[1]
     dep_tdoa(tij, n_ind, n_tot)
     if np.all(np.abs(tij) <= t_max):
         val = cc[np.arange(n_tot), tij].sum()
@@ -191,7 +192,8 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False
         ch = np.argmin(cc[np.arange(n_tot), tij][_sort].reshape(-1, n_ind).sum(1))
         other_ch = np.asarray([i for i in range(n_ind) if i != ch])
         dep = [num_ind_dep(i, ch, n_ind) if i > ch else num_ind_dep(ch, i, n_ind) for i in other_ch]
-        t_dep = truncated_argmax(cc[dep], t_max)
+        t_dep = np.argmax(cc[dep], 1)
+        t_dep[t_dep > t_max] -= cc.shape[1]
         tij[other_ch] = (-1)**(other_ch < ch) * t_dep + tij[ch]
         dep_tdoa(tij, n_ind, n_tot)
         if np.all(np.abs(tij) <= t_max):