From 92ef1080a0e84b3977e66c8bf53a4a69d9cc5492 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Mon, 21 Mar 2022 09:24:50 -0400
Subject: [PATCH] Added featrue_ids

---
 summit/multiview_platform/monoview_classifiers/adaboost.py   | 5 +++--
 .../monoview_classifiers/gradient_boosting.py                | 5 +++--
 .../multiview_platform/monoview_classifiers/random_forest.py | 5 +++--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/summit/multiview_platform/monoview_classifiers/adaboost.py b/summit/multiview_platform/monoview_classifiers/adaboost.py
index cd8ce3db..412e9a19 100644
--- a/summit/multiview_platform/monoview_classifiers/adaboost.py
+++ b/summit/multiview_platform/monoview_classifiers/adaboost.py
@@ -64,11 +64,12 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
             [step_pred for step_pred in self.staged_predict(X)])
         return pred
 
-    def get_interpretation(self, directory, base_file_name, y_test,
+    def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
                            multi_class=False):  # pragma: no cover
         interpretString = ""
         interpretString += self.get_feature_importance(directory,
-                                                       base_file_name)
+                                                       base_file_name,
+                                                       feature_ids)
         interpretString += "\n\n Estimator error | Estimator weight\n"
         interpretString += "\n".join(
             [str(error) + " | " + str(weight / sum(self.estimator_weights_)) for
diff --git a/summit/multiview_platform/monoview_classifiers/gradient_boosting.py b/summit/multiview_platform/monoview_classifiers/gradient_boosting.py
index e242dee8..77242502 100644
--- a/summit/multiview_platform/monoview_classifiers/gradient_boosting.py
+++ b/summit/multiview_platform/monoview_classifiers/gradient_boosting.py
@@ -76,14 +76,15 @@ class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier):
                 [step_pred for step_pred in self.staged_predict(X)])
         return pred
 
-    def get_interpretation(self, directory, base_file_name, y_test,
+    def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
                            multi_class=False):
         interpretString = ""
         if multi_class:
             return interpretString
         else:
             interpretString += self.get_feature_importance(directory,
-                                                           base_file_name)
+                                                           base_file_name,
+                                                           feature_ids)
             step_test_metrics = np.array(
                 [self.plotted_metric.score(y_test, step_pred) for step_pred in
                  self.step_predictions])
diff --git a/summit/multiview_platform/monoview_classifiers/random_forest.py b/summit/multiview_platform/monoview_classifiers/random_forest.py
index c0ebaaa5..f0d3578c 100644
--- a/summit/multiview_platform/monoview_classifiers/random_forest.py
+++ b/summit/multiview_platform/monoview_classifiers/random_forest.py
@@ -34,10 +34,11 @@ class RandomForest(RandomForestClassifier, BaseMonoviewClassifier):
                          ["gini", "entropy"], [random_state]]
         self.weird_strings = {}
 
-    def get_interpretation(self, directory, base_file_name, y_test,
+    def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
                            multiclass=False):
 
         interpret_string = ""
         interpret_string += self.get_feature_importance(directory,
-                                                        base_file_name)
+                                                        base_file_name,
+                                                        feature_ids)
         return interpret_string
-- 
GitLab