From d2326aa71ed25c8e5dc6ffe89494b9c00dc10f90 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Thu, 9 Mar 2023 12:40:32 -0500
Subject: [PATCH] Feat_imp_works with nans

---
 .../monoview_classifiers/ib_svm_rbf.py        |  3 +-
 .../monoview_classifiers/random_foscm.py      | 45 +++++++++++++++++++
 2 files changed, 47 insertions(+), 1 deletion(-)
 create mode 100644 summit/multiview_platform/monoview_classifiers/random_foscm.py

diff --git a/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py b/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py
index 095bc7dc..5f6305e5 100644
--- a/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py
+++ b/summit/multiview_platform/monoview_classifiers/ib_svm_rbf.py
@@ -32,7 +32,8 @@ class IBSVMRBF(BaseMonoviewClassifier, BalancedBaggingClassifier):
         for estim in self.estimators_:
             if hasattr(estim['classifier'], 'feature_importances_'):
                 self.feature_importances_ += estim['classifier'].feature_importances_
-        self.feature_importances_ /= np.sum(self.feature_importances_)
+        if np.sum(self.feature_importances_)!=0:
+            self.feature_importances_ /= np.sum(self.feature_importances_)
         return self
 
 
diff --git a/summit/multiview_platform/monoview_classifiers/random_foscm.py b/summit/multiview_platform/monoview_classifiers/random_foscm.py
new file mode 100644
index 00000000..202ac61b
--- /dev/null
+++ b/summit/multiview_platform/monoview_classifiers/random_foscm.py
@@ -0,0 +1,45 @@
+from sklearn.ensemble import RandomForestClassifier
+
+from ..monoview.monoview_utils import BaseMonoviewClassifier
+from summit.multiview_platform.utils.hyper_parameter_search import CustomRandint
+
+# Author-Info
+__author__ = "Baptiste Bauvin"
+__status__ = "Prototype"  # Production, Development, Prototype
+
+classifier_class_name = "RandomForest"
+
+
+class RandomForest(RandomForestClassifier, BaseMonoviewClassifier):
+    """
+    This class is an adaptation of scikit-learn's `RandomForestClassifier <https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html>`_
+
+
+    """
+
+    def __init__(self, random_state=None, n_estimators=10,
+                 max_depth=None, criterion='gini', **kwargs):
+
+        RandomForestClassifier.__init__(self,
+                                        n_estimators=n_estimators,
+                                        max_depth=max_depth,
+                                        criterion=criterion,
+                                        class_weight="balanced",
+                                        random_state=random_state
+                                        )
+        self.param_names = ["n_estimators", "max_depth", "criterion",
+                            "random_state"]
+        self.classed_params = []
+        self.distribs = [CustomRandint(low=1, high=300),
+                         CustomRandint(low=1, high=10),
+                         ["gini", "entropy"], [random_state]]
+        self.weird_strings = {}
+
+    def get_interpretation(self, directory, base_file_name, y_test, feature_ids,
+                           multiclass=False):
+
+        interpret_string = ""
+        interpret_string += self.get_feature_importance(directory,
+                                                        base_file_name,
+                                                        feature_ids)
+        return interpret_string
-- 
GitLab