From 564949c9e6300a9b519a8d44c4296236c760556e Mon Sep 17 00:00:00 2001
From: ferrari <maxence.ferrari@gmail.com>
Date: Fri, 30 Jun 2023 17:55:14 +0200
Subject: [PATCH] Increase checks for floating point errors

---
 gsrp_smart_util.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py
index 1096884..67a3463 100644
--- a/gsrp_smart_util.py
+++ b/gsrp_smart_util.py
@@ -208,7 +208,11 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False
             memory[k] = mask_val(v, out_val - sum(m for tij, m in max_val.items() if curr_tij.isdisjoint(tij)))
         if i and np.prod([v[0].size for v in memory.values()], dtype=np.float64) <= 1024:  # float to prevent overflow
             try:
-                return *add_all(memory, tree, cc, t_max, n_ind, n_tot), 1
+                val, tij = add_all(memory, tree, cc, t_max, n_ind, n_tot)
+                if val > out_val:  # False if val == out_val, or if floating point errors mask potentially good results
+                    return cc[np.arange(n_tot), tij].sum(), tij, 1  # recomputing to reduce floating point errors
+                else:
+                    return out_val, tij, 1
             except ValueError as e:
                 if any(v[0].size == 0 for v in memory.values()):  # due floating point error
                     return out_val, out_tij, 1  # current tij should only contain maxima for this error to occur
-- 
GitLab
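
Note (not part of the patch): a minimal standalone sketch of the rounding issue the
guard above addresses. A score assembled from many intermediate float operations can
drift by a few ULPs from the same score recomputed in one pass over the source matrix,
which is why the new code re-derives the returned value with
cc[np.arange(n_tot), tij].sum(). The data and the artificial grouping below are
synthetic and only stand in for the partial sums built up during the search.

    import numpy as np

    rng = np.random.default_rng(42)
    n_tot = 50
    cc = rng.normal(scale=1e6, size=(n_tot, n_tot))   # large magnitudes make rounding visible
    tij = rng.integers(0, n_tot, size=n_tot)          # one chosen column per row

    selected = cc[np.arange(n_tot), tij]

    # Score built up incrementally; x/3 + x*(2/3) equals x in exact arithmetic,
    # but each step rounds, so the running total drifts slightly.
    incremental = 0.0
    for x in selected:
        incremental = (incremental + x / 3.0) + x * (2.0 / 3.0)

    # Score recomputed in one pass from the source data, as the patch does.
    direct = selected.sum()

    print(incremental - direct)   # typically a small nonzero residue (a few ULPs)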