diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py index 1096884311ad2ebc2da1e9310d1b4d72a4406651..67a3463c80c8a0aedd9e7b1c79dec2c3518df49b 100644 --- a/gsrp_smart_util.py +++ b/gsrp_smart_util.py @@ -208,7 +208,11 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False memory[k] = mask_val(v, out_val - sum(m for tij, m in max_val.items() if curr_tij.isdisjoint(tij))) if i and np.prod([v[0].size for v in memory.values()], dtype=np.float64) <= 1024: # float to prevent overflow try: - return *add_all(memory, tree, cc, t_max, n_ind, n_tot), 1 + val, tij = add_all(memory, tree, cc, t_max, n_ind, n_tot) + if val > out_val: # is false is val == out_val, or if floating point error mask potential good results + return cc[np.arange(n_tot), tij].sum(), tij, 1 # recomputing to reduce floating point errors + else: + return out_val, tij, 1 except ValueError as e: if any(v[0].size == 0 for v in memory.values()): # due floating point error return out_val, out_tij, 1 # current tij should only contain maxima for this error to occur