diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py index a8b470e1d4dc078d51358465dbbb29c8a47d178e..25e6829270ebc6f9827ec1de8b0d503c3b8b638e 100644 --- a/gsrp_smart_util.py +++ b/gsrp_smart_util.py @@ -157,7 +157,7 @@ def constrained_argmax(mem, cc, tij_ind, curr_tij, used_tij, t_max, n_ind): def truncated_argmax(cc, t_max): - argmax_plus = np.argmax(cc[:, :t_max], axis=1) + argmax_plus = np.argmax(cc[:, :t_max+1], axis=1) argmax_minus = np.argmax(cc[:, -t_max:], axis=1) - t_max x = np.arange(len(argmax_minus)) return np.where(cc[x, argmax_minus] < cc[x, argmax_plus], argmax_plus, argmax_minus) @@ -174,11 +174,12 @@ def _get_mem_usage(memory): def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False, mem_limit=np.infty): memory = dict() with np.errstate(divide='ignore'): - cc = np.log10(cc) + cc = np.concatenate((np.log10(cc[:, :t_max+1]), np.log10(cc[:, -t_max:])), axis=1) out_val = cc[:, 0].sum() tij = np.zeros(n_tot, int) out_tij = tij.copy() - tij[:n_ind] = truncated_argmax(cc[:n_ind], t_max) + tij[:n_ind] = np.argmax(cc[:n_ind], 1) + tij[:n_ind][tij[:n_ind] > t_max] -= cc.shape[1] dep_tdoa(tij, n_ind, n_tot) if np.all(np.abs(tij) <= t_max): val = cc[np.arange(n_tot), tij].sum() @@ -191,7 +192,8 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False ch = np.argmin(cc[np.arange(n_tot), tij][_sort].reshape(-1, n_ind).sum(1)) other_ch = np.asarray([i for i in range(n_ind) if i != ch]) dep = [num_ind_dep(i, ch, n_ind) if i > ch else num_ind_dep(ch, i, n_ind) for i in other_ch] - t_dep = truncated_argmax(cc[dep], t_max) + t_dep = np.argmax(cc[dep], 1) + t_dep[t_dep > t_max] -= cc.shape[1] tij[other_ch] = (-1)**(other_ch < ch) * t_dep + tij[ch] dep_tdoa(tij, n_ind, n_tot) if np.all(np.abs(tij) <= t_max):