From b03bb8dc1a32a046f38abb8f9414ee333265c01f Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Wed, 26 Feb 2020 21:33:24 +0100
Subject: [PATCH] Starting tests

---
 multiview-machine-learning-omis.iml           |   2 +-
 .../mono_multi_view_classifiers/utils/base.py | 103 ++++++++++++++++--
 .../tests/test_utils/test_base.py             |  10 ++
 3 files changed, 106 insertions(+), 9 deletions(-)
 create mode 100644 multiview_platform/tests/test_utils/test_base.py

diff --git a/multiview-machine-learning-omis.iml b/multiview-machine-learning-omis.iml
index 63ce5ebc..af5fb2b5 100644
--- a/multiview-machine-learning-omis.iml
+++ b/multiview-machine-learning-omis.iml
@@ -4,11 +4,11 @@
     <content url="file://$MODULE_DIR$" />
     <orderEntry type="inheritedJdk" />
     <orderEntry type="sourceFolder" forTests="false" />
-    <orderEntry type="module" module-name="multiviewmetriclearning" />
     <orderEntry type="module" module-name="multiview_generator" />
     <orderEntry type="module" module-name="short_projects" />
     <orderEntry type="library" name="R User Library" level="project" />
     <orderEntry type="library" name="R Skeletons" level="application" />
     <orderEntry type="module" module-name="Datasets" />
+    <orderEntry type="module" module-name="scikit-multimodallearn" />
   </component>
 </module>
\ No newline at end of file
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/base.py b/multiview_platform/mono_multi_view_classifiers/utils/base.py
index e2f1c5bb..6c2dd88e 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/base.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/base.py
@@ -5,6 +5,7 @@ from datetime import timedelta as hms
 
 from multiview_platform.mono_multi_view_classifiers import metrics
 
+
 class BaseClassifier(BaseEstimator, ):
 
     def gen_best_params(self, detector):
@@ -42,17 +43,27 @@ class BaseClassifier(BaseEstimator, ):
                     zip(self.param_names, self.distribs))
 
     def params_to_string(self):
+        """
+        Formats the parameters of the classifier as a string
+        """
         return ", ".join(
             [param_name + " : " + self.to_str(param_name) for param_name in
              self.param_names])
 
     def get_config(self):
+        """
+        Generates a string containing all the information about the
+        classifier's configuration.
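+        For illustration only, the result could look something like
+        "DecisionTree with max_depth : 3", depending on the classifier and
+        its param_names.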
+        """
         if self.param_names:
             return self.__class__.__name__ + " with " + self.params_to_string()
         else:
             return self.__class__.__name__ + "with no config."
 
     def to_str(self, param_name):
+        """
+        Formats a parameter into a string
+        """
         if param_name in self.weird_strings:
             if self.weird_strings[param_name] == "class_name":
                 return self.get_params()[param_name].__class__.__name__
@@ -63,13 +74,23 @@ class BaseClassifier(BaseEstimator, ):
             return str(self.get_params()[param_name])
 
     def get_interpretation(self, directory, y_test, multi_class=False):
+        """
+        Base method that returns an empty string if there is no interpretation
+        method in the classifier's module.
+        """
         return ""
 
     def accepts_multi_class(self, random_state, n_samples=10, dim=2,
                             n_classes=3):
+        """
+        Base method to test whether the classifier accepts a multiclass task.
+        It is highly recommended to override it in the classifier's module
+        with a simple method that returns True or False, as it will speed up
+        the benchmark.
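+
+        A minimal override sketch, for a hypothetical classifier module::
+
+            def accepts_multi_class(self, random_state, n_samples=10,
+                                    dim=2, n_classes=3):
+                return True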
+        """
         if int(n_samples / n_classes) < 1:
             raise ValueError(
-                "n_samples ({}) / n_classe ({}) must be over 1".format(
+                "n_samples ({}) / n_class ({}) must be over 1".format(
                     n_samples,
                     n_classes))
         if hasattr(self, "accepts_mutli_class"):
@@ -95,6 +116,9 @@ def get_names(classed_list):
 
 
 def get_metric(metric_list):
+    """
+    Fetches the metric module from the metrics package.
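+
+    As inferred from the code below, metric_list is expected to look like
+    [[metric_name, metric_args_or_None], ...]; only its first entry is used
+    here.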
+    """
     metric_module = getattr(metrics, metric_list[0][0])
     if metric_list[0][1] is not None:
         metric_kwargs = dict((index, metricConfig) for index, metricConfig in
@@ -103,12 +127,50 @@ def get_metric(metric_list):
         metric_kwargs = {}
     return metric_module, metric_kwargs
 
+
 class ResultAnalyser():
+    """
+    A shared result analysis tool for monoview and multiview classifiers.
+    Its main purpose is to generate a txt file summarizing the results and,
+    when available, the classifier's interpretation.
+    """
 
     def __init__(self, classifier, classification_indices, k_folds,
                  hps_method, metrics_list, n_iter, class_label_names,
                  train_pred, test_pred, directory, labels, database_name,
                  nb_cores, duration):
+        """
+
+        Parameters
+        ----------
+        classifier: estimator used for classification
+
+        classification_indices: the (train_indices, test_indices) pair used
+        to split the dataset
+
+        k_folds: the sklearn StratifiedKFold object
+
+        hps_method: string naming the hyper-parameter search method
+
+        metrics_list: list of the metrics to compute on the results
+
+        n_iter: number of HPS iterations
+
+        class_label_names: list of the names of the labels
+
+        train_pred: classifier's prediction on the training set
+
+        test_pred: classifier's prediction on the testing set
+
+        directory: the directory in which to save the result analysis
+
+        labels: the full labels array (Y in sklearn)
+
+        database_name: the name of the database
+
+        nb_cores: number of cores/threads used for the classification
+
+        duration: duration of the classification
+        """
         self.classifier = classifier
         self.train_indices, self.test_indices = classification_indices
         self.k_folds = k_folds
@@ -127,23 +189,29 @@ class ResultAnalyser():
         self.metric_scores = {}
 
     def get_all_metrics_scores(self, ):
+        """
+        Computes the scores for all the metrics in the list and stores them
+        in self.metric_scores.
+        """
         for metric, metric_args in self.metrics_list:
             self.metric_scores[metric] = self.get_metric_scores(metric,
                                                                     metric_args)
 
     def get_metric_scores(self, metric, metric_kwargs):
         """
+        Get the train and test scores for a specific metric and its arguments
 
         Parameters
         ----------
 
-        metric :
+        metric : name of the metric, must be implemented in metrics
 
-        metric_kwargs :
+        metric_kwargs : the dictionary containing the arguments for the metric.
 
         Returns
         -------
-        list of [train_score, test_score]
+        train_score, test_score
         """
         metric_module = getattr(metrics, metric)
         train_score = metric_module.score(y_true=self.labels[self.train_indices],
@@ -156,17 +224,17 @@ class ResultAnalyser():
 
     def print_metric_score(self,):
         """
-        this function print the metrics scores
+        Generates a string formatting the configuration and scores of each metric
 
         Parameters
         ----------
-        metric_scores : the score of metrics
+        metric_scores : dictionary of (train_score, test_score) for each metric
 
         metric_list : list of metrics
 
         Returns
         -------
-        metric_score_string string containing all metric results
+        metric_score_string : a string formatting all the metric results
         """
         metric_score_string = "\n\n"
         for metric, metric_kwargs in self.metrics_list:
@@ -188,13 +256,14 @@ class ResultAnalyser():
 
     def get_db_config_string(self,):
         """
+        Generates a string formatting all the information about the database
 
         Parameters
         ----------
 
         Returns
         -------
-
+        db_config_string : a string formatting all the information about the
+        database
         """
         learning_ratio = len(self.train_indices) / (
                 len(self.train_indices) + len(self.test_indices))
@@ -208,6 +277,13 @@ class ResultAnalyser():
         return db_config_string
 
     def get_classifier_config_string(self, ):
+        """
+        Formats the information about the classifier and its configuration
+
+        Returns
+        -------
+        classifier_config_string : a string describing the classifier's
+        configuration
+        """
         classifier_config_string = "Classifier configuration : \n"
         classifier_config_string += "\t- " + self.classifier.get_config()+ "\n"
         classifier_config_string += "\t- Executed on {} core(s) \n".format(
@@ -218,6 +294,17 @@ class ResultAnalyser():
         return classifier_config_string
 
     def analyze(self, ):
+        """
+        Main function used in the monoview and multiview classification scripts
+
+        Returns
+        -------
+        string_analysis : a string that will be stored in the log and in a txt
+        file
+        image_analysis : a list of images to save
+        metric_scores : a dictionary of {metric: (train_score, test_score)}
+        used in later analysis.
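+
+        A usage sketch (assuming a fully configured analyser instance)::
+
+            string_analysis, image_analysis, metric_scores = \
+                result_analyser.analyze()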
+        """
         string_analysis = self.get_base_string()
         string_analysis += self.get_db_config_string()
         string_analysis += self.get_classifier_config_string()
diff --git a/multiview_platform/tests/test_utils/test_base.py b/multiview_platform/tests/test_utils/test_base.py
new file mode 100644
index 00000000..c4cd5998
--- /dev/null
+++ b/multiview_platform/tests/test_utils/test_base.py
@@ -0,0 +1,10 @@
+# import os
+# import unittest
+# import yaml
+# import numpy as np
+#
+# from multiview_platform.tests.utils import rm_tmp, tmp_path
+# from multiview_platform.mono_multi_view_classifiers.utils import base
+#
+#
+# class Test_ResultAnalyzer(unittest.TestCase):
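+#
+#     # Sketch of a hypothetical first test, kept commented out like the
+#     # rest of this file; assumes a concrete analyser and fake predictions.
+#     def test_simple(self):
+#         pass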
-- 
GitLab