From ecf491b568f5c1cfca70298ba6518cf2164cffb1 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Wed, 27 Nov 2019 15:40:07 -0500
Subject: [PATCH] Added feature importances on coef_ compatible algorithms

---
 config_files/config_test.yml                                | 4 ++--
 .../monoview_classifiers/lasso.py                           | 1 +
 .../mono_multi_view_classifiers/monoview_classifiers/sgd.py | 2 ++
 .../monoview_classifiers/svm_linear.py                      | 6 ++++++
 4 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/config_files/config_test.yml b/config_files/config_test.yml
index ecf24d31..a3d3fbc2 100644
--- a/config_files/config_test.yml
+++ b/config_files/config_test.yml
@@ -22,8 +22,8 @@ Classification:
   nb_folds: 2
   nb_class: 2
   classes:
-  type: ["multiview","monoview"]
-  algos_monoview: ["decision_tree", "adaboost"]
+  type: ["monoview"]
+  algos_monoview: ["all", ]
   algos_multiview: ["weighted_linear_early_fusion"]
   stats_iter: 2
   metrics: ["accuracy_score", "f1_score"]
diff --git a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/lasso.py b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/lasso.py
index 0af82bc6..c36ad031 100644
--- a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/lasso.py
+++ b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/lasso.py
@@ -63,6 +63,7 @@ class Lasso(LassoSK, BaseMonoviewClassifier):
         neg_y = np.copy(y)
         neg_y[np.where(neg_y == 0)] = -1
         super(Lasso, self).fit(X, neg_y)
+        self.feature_importances_ = self.coef_/np.sum(self.coef_)
         return self
 
     def predict(self, X):
diff --git a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/sgd.py b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/sgd.py
index b4a0e3d7..bd0d8c70 100644
--- a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/sgd.py
+++ b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/sgd.py
@@ -79,6 +79,8 @@ class SGD(SGDClassifier, BaseMonoviewClassifier):
         interpret_string str to interpreted
         """
         interpret_string = ""
+        import numpy as np
+        self.feature_importances_ = (self.coef_/np.sum(self.coef_)).reshape(self.coef_.shape[1])
         return interpret_string
 
 
diff --git a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/svm_linear.py b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/svm_linear.py
index b71c1d31..ad867f07 100644
--- a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/svm_linear.py
+++ b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/svm_linear.py
@@ -34,6 +34,12 @@ class SVMLinear(SVCClassifier, BaseMonoviewClassifier):
         self.param_names = ["C", "random_state"]
         self.distribs = [CustomUniform(loc=0, state=1), [random_state]]
 
+    def getInterpret(self, directory, y_test):
+        interpret_string = ""
+        import numpy as np
+        self.feature_importances_ = (self.coef_/np.sum(self.coef_)).reshape((self.coef_.shape[1],))
+        return interpret_string
+
 
 # def formatCmdArgs(args):
 #     """Used to format kwargs for the parsed args"""
-- 
GitLab