From 32f67aa745d7d5b0cd95332bc2f2fd17e42342a0 Mon Sep 17 00:00:00 2001
From: ferrari <maxence.ferrari@gmail.com>
Date: Wed, 22 Feb 2023 16:12:04 +0100
Subject: [PATCH] Fix floating point error

---
 gsrp_smart_util.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py
index 25e6829..0e42e8d 100644
--- a/gsrp_smart_util.py
+++ b/gsrp_smart_util.py
@@ -207,7 +207,12 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False
             curr_tij = frozenset(tree[k[0]][k[1]])
             memory[k] = mask_val(v, out_val - sum(m for tij, m in max_val.items() if curr_tij.isdisjoint(tij)))
         if i and np.prod([v[0].size for v in memory.values()], dtype=np.float64) <= 1024:  # float to prevent overflow
-            return *add_all(memory, tree, cc, t_max, n_ind, n_tot), 1
+            try:
+                return *add_all(memory, tree, cc, t_max, n_ind, n_tot), 1
+            except ValueError as e:
+                if any(v[0].size == 0 for v in memory.values()):  # due floating point error
+                    return out_val, out_tij, 1  # current tij should only contain maxima for this error to occur
+                raise e from None
         for j, op in enumerate(step):
             if op == 'mem':
                 if i == 0:
-- 
GitLab