From 2c11235e9ada682c7789d99553bfb747f9a478d9 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Thu, 5 Mar 2020 16:42:14 +0100
Subject: [PATCH] Tracebacks included

---
 config_files/config_test.yml                                   | 2 +-
 .../utils/hyper_parameter_search.py                            | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/config_files/config_test.yml b/config_files/config_test.yml
index 64cb1273..0a83c992 100644
--- a/config_files/config_test.yml
+++ b/config_files/config_test.yml
@@ -13,7 +13,7 @@ debug: True
 add_noise: False
 noise_std: 0.0
 res_dir: "../results/"
-track_tracebacks: False
+track_tracebacks: True
 
 # All the classification-realted configuration options
 multiclass_method: "oneVersusOne"
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py
index 8132b9ab..734092cd 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py
@@ -65,7 +65,6 @@ class HPSearch:
         self.cv_results_["mean_test_score"] = []
         self.cv_results_["params"] = []
         n_failed = 0
-        self.tracebacks = []
         self.tracebacks_params = []
         for candidate_param_idx, candidate_param in enumerate(self.candidate_params):
             test_scores = np.zeros(n_splits) + 1000
@@ -164,6 +163,7 @@ class Random(RandomizedSearchCV, HPSearch):
         self.view_indices = view_indices
         self.equivalent_draws = equivalent_draws
         self.track_tracebacks = track_tracebacks
+        self.tracebacks=[]
 
     def get_param_distribs(self, estimator):
         if isinstance(estimator, MultiClassWrapper):
@@ -208,6 +208,7 @@ class Grid(GridSearchCV, HPSearch):
         self.available_indices = learning_indices
         self.view_indices = view_indices
         self.track_tracebacks = track_tracebacks
+        self.tracebacks = []
 
     def fit(self, X, y=None, groups=None, **fit_params):
         if self.framework == "monoview":
-- 
GitLab