diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py index d7ac8ab9b2ad803a71a6303a4d786b63889ad6dd..ba9c0a63900f8ee7cb60636472beae5c95d2384a 100644 --- a/gsrp_smart_util.py +++ b/gsrp_smart_util.py @@ -126,15 +126,15 @@ def constrained_argmax(mem, cc, tij_ind, curr_tij, used_tij, t_max, n_ind): def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list): memory = dict() - val = 0 + val = cc[:, 0].prod() tij = np.zeros(n_tot, int) for i, step in enumerate(program): # increase dimensions for j, op in enumerate(step): if op == 'mem': if i == 0: - memory[(0, j)] = (np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])), - np.arange(-t_max, t_max + 1)[np.newaxis]) + memory[(0, j)] = mask_val((np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])), + np.arange(-t_max, t_max + 1)[np.newaxis]), val) else: memory[(i, j)] = mask_val(memory[(op.left, op.right)], val) else: # op == 'mul' @@ -150,7 +150,7 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list): done_tij.update(curr_tij) dep_tdoa(tij, n_ind, n_tot) val = max(cc[np.arange(n_tot), tij].prod(), val) - except ValueError: # search for potential maxima outside of possible values + except ValueError: # search of potential maxima ended outside of possible values tij_min, tij_max = memory[(i, j)][1].min(), memory[(i, j)][1].max() for k in range(j): memory[(i, k)] = mask_lim(memory[(i, k)], tij_min, tij_max, t_max)