From 336421c5fa0c55504cca4f045c7a55571b2ee023 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Thu, 14 Feb 2019 07:10:54 -0500
Subject: [PATCH] Adjustments on training curve with noise

---
 .../MonoMultiViewClassifiers/Monoview/MonoviewUtils.py   | 2 +-
 .../MonoviewClassifiers/Adaboost.py                      | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/multiview_platform/MonoMultiViewClassifiers/Monoview/MonoviewUtils.py b/multiview_platform/MonoMultiViewClassifiers/Monoview/MonoviewUtils.py
index 1e7fd817..0d4c7829 100644
--- a/multiview_platform/MonoMultiViewClassifiers/Monoview/MonoviewUtils.py
+++ b/multiview_platform/MonoMultiViewClassifiers/Monoview/MonoviewUtils.py
@@ -31,7 +31,7 @@ def randomizedSearch(X_train, y_train, randomState, outputFileName, classifierMo
             nIter = nb_possible_combinations
         randomSearch = RandomizedSearchCV(estimator, n_iter=nIter, param_distributions=params_dict, refit=True,
                                           n_jobs=nbCores, scoring=scorer, cv=KFolds, random_state=randomState)
-        print(estimator)
+        print(X_train)
         detector = randomSearch.fit(X_train, y_train)
 
         bestParams = estimator.genBestParams(detector)
diff --git a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py
index 68c1b2d0..4153a0ee 100644
--- a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py
+++ b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py
@@ -29,6 +29,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
         self.weird_strings = {"base_estimator": "class_name"}
         self.plotted_metric = Metrics.zero_one_loss
         self.plotted_metric_name = "zero_one_loss"
+        self.step_predictions = None
 
     def fit(self, X, y, sample_weight=None):
         super(Adaboost, self).fit(X, y, sample_weight=sample_weight)
@@ -40,12 +41,20 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
         """Used to know if the classifier can return label probabilities"""
         return True
 
+    def predict(self, X):
+        super(Adaboost, self).predict(X)
+        self.step_predictions = np.array([step_pred for step_pred in self.staged_predict(X)])
+
     def getInterpret(self, directory, y_test):
         interpretString = ""
         interpretString += self.getFeatureImportance(directory)
         interpretString += "\n\n Estimator error | Estimator weight\n"
         interpretString += "\n".join([str(error) +" | "+ str(weight/sum(self.estimator_weights_)) for error, weight in zip(self.estimator_errors_, self.estimator_weights_)])
+        step_test_metrics = np.array([self.plotted_metric.score(y_test, step_pred) for step_pred in self.step_predictions])
+        get_accuracy_graph(step_test_metrics, "Adaboost", directory + "test_metrics.png",
+                           self.plotted_metric_name, set="test")
         get_accuracy_graph(self.metrics, "Adaboost", directory+"metrics.png", self.plotted_metric_name, bounds=list(self.bounds))
+        np.savetxt(directory + "test_metrics.csv", step_test_metrics, delimiter=',')
         np.savetxt(directory + "train_metrics.csv", self.metrics, delimiter=',')
         return interpretString
 
-- 
GitLab