From cf34b931a62d9497bf1ce75476432aa53467eb80 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Tue, 21 Jan 2020 14:25:35 +0100
Subject: [PATCH] Corrected cross-validation in multiview hyper-parameter search

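In fit_multiview, the folds returned by self.cv.split() were a one-shot
generator, so only the first candidate parameter set could iterate over
them; wrapping the call in list() keeps the folds available for every
draw. The equivalent_draws scaling of n_iter is also applied before the
candidate parameters are sampled, so the extra draws are actually taken
into account when building the parameter list.

A minimal sketch of the generator issue (illustrative only, not project
code):

    folds = (fold for fold in [(0, 1), (2, 3)])  # generator: one pass only
    for candidate in ["a", "b"]:
        print(candidate, list(folds))  # "a" sees both folds, "b" gets []: the generator is spent
    # materialising the folds with list(...) lets every candidate reuse them

Also included:
* weighted_linear_early_fusion: short_name is built from
  self.monoview_classifier_name and re-set once the monoview estimator is
  (re)initialised, so it matches the classifier actually in use.
* result_analysis: tracebacks are published before the metric, error and
  feature-importance graphs, and the plotly figures get transparent
  backgrounds.
* config_test.yml: the test configuration now targets the
  awa-tiger-wolf-all dataset, with a 0.9 split, a single stats iteration
  and a reduced algorithm list.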
---
 config_files/config_test.yml                   | 18 +++++++++---------
 .../weighted_linear_early_fusion.py            |  3 ++-
 .../result_analysis.py                         | 15 ++++++++++-----
 .../utils/hyper_parameter_search.py            |  6 +++---
 4 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/config_files/config_test.yml b/config_files/config_test.yml
index f54013c7..d3142c90 100644
--- a/config_files/config_test.yml
+++ b/config_files/config_test.yml
@@ -1,11 +1,11 @@
 # The base configuration of the benchmark
 Base :
   log: True
-  name: ["outliers_dset"]
+  name: ["awa-tiger-wolf-all"]
   label: "_"
   type: ".hdf5"
   views:
-  pathf: "/home/baptiste/Documents/Datasets/Generated/outliers_dset/"
+  pathf: "/home/baptiste/Documents/Datasets/AWA/base/"
   nice: 0
   random_state: 42
   nb_cores: 1
@@ -18,16 +18,16 @@ Base :
 # All the classification-related configuration options
 Classification:
   multiclass_method: "oneVersusOne"
-  split: 0.2
+  split: 0.9
   nb_folds: 2
   nb_class: 2
   classes:
-  type: ["monoview", "multiview"]
-  algos_monoview: ["decision_tree", "adaboost", "svm_linear", "random_forest"]
-  algos_multiview: ["weighted_linear_early_fusion", "difficulty_fusion", "double_fault_fusion"]
-  stats_iter: 30
+  type: ["multiview", "monoview"]
+  algos_monoview: ["decision_tree", "adaboost", "random_forest" ]
+  algos_multiview: ["weighted_linear_early_fusion",]
+  stats_iter: 1
   metrics: ["accuracy_score", "f1_score"]
-  metric_princ: "accuracy_score"
+  metric_princ: "f1_score"
   hps_type: "randomized_search-equiv"
   hps_iter: 5
 
@@ -65,7 +65,7 @@ adaboost_graalpy:
   n_stumps: [1]
 
 decision_tree:
-  max_depth: [10]
+  max_depth: [2]
   criterion: ["gini"]
   splitter: ["best"]
 
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py
index 484b6657..170c1010 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py
@@ -34,7 +34,7 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
         super(WeightedLinearEarlyFusion, self).__init__(random_state=random_state)
         self.view_weights = view_weights
         self.monoview_classifier_name = monoview_classifier_name
-        self.short_name = "early fusion " + monoview_classifier_name
+        self.short_name = "early fusion " + self.monoview_classifier_name
         if monoview_classifier_name in monoview_classifier_config:
             self.monoview_classifier_config = monoview_classifier_config[monoview_classifier_name]
         self.monoview_classifier_config = monoview_classifier_config
@@ -59,6 +59,7 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
         self.monoview_classifier = monoview_classifier_class()
         self.init_monoview_estimator(monoview_classifier_name,
                                        monoview_classifier_config)
+        self.short_name = "early fusion " + self.monoview_classifier_name
         return self
 
     def get_params(self, deep=True):
diff --git a/multiview_platform/mono_multi_view_classifiers/result_analysis.py b/multiview_platform/mono_multi_view_classifiers/result_analysis.py
index f529f1af..fdf0928d 100644
--- a/multiview_platform/mono_multi_view_classifiers/result_analysis.py
+++ b/multiview_platform/mono_multi_view_classifiers/result_analysis.py
@@ -4,6 +4,7 @@ import logging
 import os
 import time
 import yaml
+import traceback
 
 import matplotlib as mpl
 from matplotlib.patches import Patch
@@ -162,6 +163,8 @@ def plot_metric_scores(train_scores, test_scores, names, nb_results, metric_name
         ))
 
         fig.update_layout(title=metric_name + "\n" + tag + " scores for each classifier")
+        fig.update_layout(paper_bgcolor='rgba(0,0,0,0)',
+                          plot_bgcolor='rgba(0,0,0,0)')
         plotly.offline.plot(fig, filename=file_name + ".html", auto_open=False)
         del fig
 
@@ -232,7 +235,8 @@ def plot_2d(data, classifiers_names, nbClassifiers, nbExamples,
                 reversescale=True), row=row_index+1, col=1)
             fig.update_yaxes(title_text="Label "+str(row_index), showticklabels=False, ticks='', row=row_index+1, col=1)
             fig.update_xaxes(showticklabels=False, row=row_index+1, col=1)
-
+        fig.update_layout(paper_bgcolor='rgba(0,0,0,0)',
+                          plot_bgcolor='rgba(0,0,0,0)')
         fig.update_xaxes(showticklabels=True, row=len(label_index_list), col=1)
         plotly.offline.plot(fig, filename=file_name + "error_analysis_2D.html", auto_open=False)
         del fig
@@ -629,6 +633,8 @@ def publish_feature_importances(feature_importances, directory, database_name, l
         fig.update_layout(
             xaxis={"showgrid": False, "showticklabels": False, "ticks": ''},
             yaxis={"showgrid": False, "showticklabels": False, "ticks": ''})
+        fig.update_layout(paper_bgcolor='rgba(0,0,0,0)',
+                          plot_bgcolor='rgba(0,0,0,0)')
         plotly.offline.plot(fig, filename=file_name + ".html", auto_open=False)
 
         del fig
@@ -724,7 +730,7 @@ def analyze_biclass(results, benchmark_argument_dictionaries, stats_iter, metric
     logging.debug("Srart:\t Analzing all biclass resuls")
     biclass_results = {}
     flagged_tracebacks_list = []
-
+    fig_errors = []
     for flag, result, tracebacks in results:
         iteridex, [classifierPositive, classifierNegative] = flag
 
@@ -739,14 +745,13 @@ def analyze_biclass(results, benchmark_argument_dictionaries, stats_iter, metric
         labels_names = [arguments["labels_dictionary"][0],
                        arguments["labels_dictionary"][1]]
 
+        flagged_tracebacks_list += publish_tracebacks(directory, database_name, labels_names, tracebacks, flag)
         results = publishMetricsGraphs(metrics_scores, directory, database_name,
-                             labels_names)
+                                       labels_names)
         publishExampleErrors(example_errors, directory, database_name,
                              labels_names, example_ids, arguments["labels"])
         publish_feature_importances(feature_importances, directory, database_name, labels_names)
 
-        flagged_tracebacks_list += publish_tracebacks(directory, database_name, labels_names, tracebacks, flag)
-
 
         if not str(classifierPositive) + str(classifierNegative) in biclass_results:
             biclass_results[str(classifierPositive) + str(classifierNegative)] = {}
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py
index bc27b3d3..46617c6e 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/hyper_parameter_search.py
@@ -180,12 +180,12 @@ class MultiviewCompatibleRandomizedSearchCV(RandomizedSearchCV):
 
     def fit_multiview(self, X, y=None, groups=None, **fit_params):
         n_splits = self.cv.get_n_splits(self.available_indices, y[self.available_indices])
-        folds = self.cv.split(self.available_indices, y[self.available_indices])
+        folds = list(self.cv.split(self.available_indices, y[self.available_indices]))
+        if self.equivalent_draws:
+            self.n_iter = self.n_iter*X.nb_view
         candidate_params = list(self._get_param_iterator())
         base_estimator = clone(self.estimator)
         results = {}
-        if self.equivalent_draws:
-            self.n_iter = self.n_iter*X.nb_view
         self.cv_results_ = dict(("param_"+param_name, []) for param_name in candidate_params[0].keys())
         self.cv_results_["mean_test_score"] = []
         for candidate_param_idx, candidate_param in enumerate(candidate_params):
-- 
GitLab