Skip to content
Snippets Groups Projects
Commit 782be818 authored by ferrari's avatar ferrari
Browse files

Speed up smart mode for noise signals

parent 4b2cec2d
No related branches found
No related tags found
No related merge requests found
...@@ -126,15 +126,15 @@ def constrained_argmax(mem, cc, tij_ind, curr_tij, used_tij, t_max, n_ind): ...@@ -126,15 +126,15 @@ def constrained_argmax(mem, cc, tij_ind, curr_tij, used_tij, t_max, n_ind):
def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list): def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list):
memory = dict() memory = dict()
val = 0 val = cc[:, 0].prod()
tij = np.zeros(n_tot, int) tij = np.zeros(n_tot, int)
for i, step in enumerate(program): for i, step in enumerate(program):
# increase dimensions # increase dimensions
for j, op in enumerate(step): for j, op in enumerate(step):
if op == 'mem': if op == 'mem':
if i == 0: if i == 0:
memory[(0, j)] = (np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])), memory[(0, j)] = mask_val((np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])),
np.arange(-t_max, t_max + 1)[np.newaxis]) np.arange(-t_max, t_max + 1)[np.newaxis]), val)
else: else:
memory[(i, j)] = mask_val(memory[(op.left, op.right)], val) memory[(i, j)] = mask_val(memory[(op.left, op.right)], val)
else: # op == 'mul' else: # op == 'mul'
...@@ -150,7 +150,7 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list): ...@@ -150,7 +150,7 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list):
done_tij.update(curr_tij) done_tij.update(curr_tij)
dep_tdoa(tij, n_ind, n_tot) dep_tdoa(tij, n_ind, n_tot)
val = max(cc[np.arange(n_tot), tij].prod(), val) val = max(cc[np.arange(n_tot), tij].prod(), val)
except ValueError: # search for potential maxima outside of possible values except ValueError: # search of potential maxima ended outside of possible values
tij_min, tij_max = memory[(i, j)][1].min(), memory[(i, j)][1].max() tij_min, tij_max = memory[(i, j)][1].min(), memory[(i, j)][1].max()
for k in range(j): for k in range(j):
memory[(i, k)] = mask_lim(memory[(i, k)], tij_min, tij_max, t_max) memory[(i, k)] = mask_lim(memory[(i, k)], tij_min, tij_max, t_max)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment