Skip to content
Snippets Groups Projects
Commit b289d8ed authored by ferrari's avatar ferrari
Browse files

Improved speed and memory consumption of smart mode

parent cdd646f5
No related branches found
No related tags found
No related merge requests found
......@@ -157,7 +157,7 @@ def constrained_argmax(mem, cc, tij_ind, curr_tij, used_tij, t_max, n_ind):
def truncated_argmax(cc, t_max):
argmax_plus = np.argmax(cc[:, :t_max], axis=1)
argmax_plus = np.argmax(cc[:, :t_max+1], axis=1)
argmax_minus = np.argmax(cc[:, -t_max:], axis=1) - t_max
x = np.arange(len(argmax_minus))
return np.where(cc[x, argmax_minus] < cc[x, argmax_plus], argmax_plus, argmax_minus)
......@@ -174,11 +174,12 @@ def _get_mem_usage(memory):
def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False, mem_limit=np.infty):
memory = dict()
with np.errstate(divide='ignore'):
cc = np.log10(cc)
cc = np.concatenate((np.log10(cc[:, :t_max+1]), np.log10(cc[:, -t_max:])), axis=1)
out_val = cc[:, 0].sum()
tij = np.zeros(n_tot, int)
out_tij = tij.copy()
tij[:n_ind] = truncated_argmax(cc[:n_ind], t_max)
tij[:n_ind] = np.argmax(cc[:n_ind], 1)
tij[:n_ind][tij[:n_ind] > t_max] -= cc.shape[1]
dep_tdoa(tij, n_ind, n_tot)
if np.all(np.abs(tij) <= t_max):
val = cc[np.arange(n_tot), tij].sum()
......@@ -191,7 +192,8 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False
ch = np.argmin(cc[np.arange(n_tot), tij][_sort].reshape(-1, n_ind).sum(1))
other_ch = np.asarray([i for i in range(n_ind) if i != ch])
dep = [num_ind_dep(i, ch, n_ind) if i > ch else num_ind_dep(ch, i, n_ind) for i in other_ch]
t_dep = truncated_argmax(cc[dep], t_max)
t_dep = np.argmax(cc[dep], 1)
t_dep[t_dep > t_max] -= cc.shape[1]
tij[other_ch] = (-1)**(other_ch < ch) * t_dep + tij[ch]
dep_tdoa(tij, n_ind, n_tot)
if np.all(np.abs(tij) <= t_max):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment