From 540ea2544b7445c88b1744a303ee5790acafea1f Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Fri, 24 Mar 2023 07:17:46 -0400
Subject: [PATCH] test

---
 summit/multiview_platform/exec_classif.py     | 14 +++----------
 .../result_analysis/error_analysis.py         | 21 +++++++++++++------
 .../result_analysis/feature_importances.py    |  1 +
 3 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/summit/multiview_platform/exec_classif.py b/summit/multiview_platform/exec_classif.py
index 73b87d77..09f708aa 100644
--- a/summit/multiview_platform/exec_classif.py
+++ b/summit/multiview_platform/exec_classif.py
@@ -955,7 +955,6 @@ class Summit(BaseExec):
                            metrics=self.metrics)
         return plif
 
-
     def gen_single_multiview_arg_dictionary(self, classifier_name, arguments, nb_class,
                                             hps_kwargs, views_dictionary=None,):
         if classifier_name in arguments:
@@ -975,10 +974,7 @@ class Summit(BaseExec):
                            database_name=self.name,
                            hps_type=self.hps_type,
                            nb_cores=self.nb_cores,
-                           metrics=self.metrics,
-
-                           )
-
+                           metrics=self.metrics,)
 
     def extract_dict(self, classifier_config):
         """Reverse function of get_path_dict"""
@@ -987,7 +983,6 @@ class Summit(BaseExec):
             extracted_dict = self.set_element(extracted_dict, key, value)
         return extracted_dict
 
-
     def set_element(self, dictionary, path, value):
         """Set value in dictionary at the location indicated by path"""
         existing_keys = path.split(".")[:-1]
@@ -1001,7 +996,6 @@ class Summit(BaseExec):
         dict_state[path.split(".")[-1]] = value
         return dictionary
 
-
     def get_path_dict(self, multiview_classifier_args):
         """This function is used to generate a dictionary with each key being
         the path to the value.
@@ -1018,7 +1012,6 @@ class Summit(BaseExec):
             paths = self.is_dict_in(path_dict)
         return path_dict
 
-
     def is_dict_in(self, dictionary):
         """
         Returns True if any of the dictionary value is a dictionary itself.
@@ -1037,7 +1030,6 @@ class Summit(BaseExec):
                 paths.append(key)
         return paths
 
-
     def init_kwargs(self, classifiers_names, framework="monoview"):
         r"""Used to init kwargs thanks to a function in each monoview classifier package.
 
@@ -1068,8 +1060,8 @@ class Summit(BaseExec):
                         getattr(multiview_classifiers, classifiers_name)
             except AttributeError:
                 raise AttributeError(
-                    classifiers_name + " is not implemented in monoview_classifiers, "
-                                       "please specify the name of the file in monoview_classifiers")
+                    classifiers_name + " is not implemented in {}_classifiers, "
+                                       "please specify the name of the file in {}_classifiers".format(framework, framework))
             if classifiers_name in self.args:
                 kwargs[classifiers_name] = self.args[classifiers_name]
             else:
diff --git a/summit/multiview_platform/result_analysis/error_analysis.py b/summit/multiview_platform/result_analysis/error_analysis.py
index 67240446..96dbc9bb 100644
--- a/summit/multiview_platform/result_analysis/error_analysis.py
+++ b/summit/multiview_platform/result_analysis/error_analysis.py
@@ -55,9 +55,13 @@ def publish_sample_errors(sample_errors, directory, database_name,
     nb_classifiers, nb_samples, classifiers_names, \
         data_2d, error_on_samples = gen_error_data(sample_errors)
 
-    np.savetxt(base_file_name + "2D_plot_data.csv", data_2d, delimiter=",")
-    np.savetxt(base_file_name + "bar_plot_data.csv", error_on_samples,
-               delimiter=",")
+    heat_map_data = pd.DataFrame(index=sample_ids, columns=classifiers_names, data=data_2d)
+    bar_plot_data = pd.DataFrame(index=sample_ids, data=error_on_samples)
+    heat_map_data.to_csv(base_file_name + "2D_plot_data.csv")
+    bar_plot_data.to_csv(base_file_name + "bar_plot_data.csv")
+    # np.savetxt(base_file_name + "2D_plot_data.csv", data_2d, delimiter=",")
+    # np.savetxt(base_file_name + "bar_plot_data.csv", error_on_samples,
+    #            delimiter=",")
 
     plot_2d(data_2d, classifiers_names, nb_classifiers, base_file_name, database_name,
             sample_ids=sample_ids, labels=labels, label_names=label_names, test=test)
@@ -82,9 +86,14 @@ def publish_all_sample_errors(iter_results, directory,
         add='t'
     else:
         add = ""
-    np.savetxt(os.path.join(directory, "clf_errors{}.csv".format(add)), data, delimiter=",")
-    np.savetxt(os.path.join(directory, "sample_errors{}.csv".format(add)), error_on_samples,
-               delimiter=",")
+    heat_map_data = pd.DataFrame(index=sample_ids, columns=classifier_names,
+                                 data=data)
+    bar_plot_data = pd.DataFrame(index=sample_ids, data=error_on_samples)
+    heat_map_data.to_csv(os.path.join(directory, "clf_errors{}.csv".format(add)))
+    bar_plot_data.to_csv(os.path.join(directory, "sample_errors{}.csv".format(add)))
+    # np.savetxt(os.path.join(directory, "clf_errors{}.csv".format(add)), data, delimiter=",")
+    # np.savetxt(os.path.join(directory, "sample_errors{}.csv".format(add)), error_on_samples,
+    #            delimiter=",")
     df = pd.DataFrame(index = sample_ids, columns=["err"], data=1-error_on_samples)
     df.to_csv(os.path.join(directory, "sample_err_df{}.csv".format(add)))
     plot_2d(data, classifier_names, nb_classifiers,
diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py
index ec233dc2..ad92c6cd 100644
--- a/summit/multiview_platform/result_analysis/feature_importances.py
+++ b/summit/multiview_platform/result_analysis/feature_importances.py
@@ -25,6 +25,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
     """
     feature_importances = {}
     for classifier_result in result:
+        print((classifier_result.classifier_name))
         if isinstance(classifier_result, MonoviewResult):
             if classifier_result.view_name not in feature_importances:
                 feature_importances[classifier_result.view_name] = pd.DataFrame(
-- 
GitLab