From b03bb8dc1a32a046f38abb8f9414ee333265c01f Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Wed, 26 Feb 2020 21:33:24 +0100
Subject: [PATCH] Starting tests

---
 multiview-machine-learning-omis.iml            |   2 +-
 .../mono_multi_view_classifiers/utils/base.py  | 103 ++++++++++++++++--
 .../tests/test_utils/test_base.py              |  10 ++
 3 files changed, 106 insertions(+), 9 deletions(-)
 create mode 100644 multiview_platform/tests/test_utils/test_base.py

diff --git a/multiview-machine-learning-omis.iml b/multiview-machine-learning-omis.iml
index 63ce5ebc..af5fb2b5 100644
--- a/multiview-machine-learning-omis.iml
+++ b/multiview-machine-learning-omis.iml
@@ -4,11 +4,11 @@
     <content url="file://$MODULE_DIR$" />
     <orderEntry type="inheritedJdk" />
     <orderEntry type="sourceFolder" forTests="false" />
-    <orderEntry type="module" module-name="multiviewmetriclearning" />
     <orderEntry type="module" module-name="multiview_generator" />
     <orderEntry type="module" module-name="short_projects" />
     <orderEntry type="library" name="R User Library" level="project" />
     <orderEntry type="library" name="R Skeletons" level="application" />
     <orderEntry type="module" module-name="Datasets" />
+    <orderEntry type="module" module-name="scikit-multimodallearn" />
   </component>
 </module>
\ No newline at end of file
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/base.py b/multiview_platform/mono_multi_view_classifiers/utils/base.py
index e2f1c5bb..6c2dd88e 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/base.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/base.py
@@ -5,6 +5,7 @@ from datetime import timedelta as hms

 from multiview_platform.mono_multi_view_classifiers import metrics

+
 class BaseClassifier(BaseEstimator, ):

     def gen_best_params(self, detector):
@@ -42,17 +43,27 @@ class BaseClassifier(BaseEstimator, ):
                 zip(self.param_names, self.distribs))

     def params_to_string(self):
+        """
+        Formats the parameters of the classifier as a string
+        """
         return ", ".join(
             [param_name + " : " + self.to_str(param_name)
              for param_name in self.param_names])

     def get_config(self):
+        """
+        Generates a string containing all the information about the
+        classifier's configuration
+        """
         if self.param_names:
             return self.__class__.__name__ + " with " + self.params_to_string()
         else:
             return self.__class__.__name__ + "with no config."

     def to_str(self, param_name):
+        """
+        Formats a parameter into a string
+        """
         if param_name in self.weird_strings:
             if self.weird_strings[param_name] == "class_name":
                 return self.get_params()[param_name].__class__.__name__
@@ -63,13 +74,23 @@ class BaseClassifier(BaseEstimator, ):
             return str(self.get_params()[param_name])

     def get_interpretation(self, directory, y_test, multi_class=False):
+        """
+        Base method that returns an empty string if there is no interpretation
+        method in the classifier's module
+        """
         return ""

     def accepts_multi_class(self, random_state, n_samples=10, dim=2,
                             n_classes=3):
+        """
+        Base function to test if the classifier accepts a multiclass task.
+        It is highly recommended to override it with a simple method that
+        returns True or False in the classifier's module, as it will speed up
+        the benchmark
+        """
         if int(n_samples / n_classes) < 1:
             raise ValueError(
-                "n_samples ({}) / n_classe ({}) must be over 1".format(
+                "n_samples ({}) / n_class ({}) must be over 1".format(
                     n_samples,
                     n_classes))
         if hasattr(self, "accepts_mutli_class"):
@@ -95,6 +116,9 @@ def get_names(classed_list):


 def get_metric(metric_list):
+    """
+    Fetches the metric module from the metrics package
+    """
     metric_module = getattr(metrics, metric_list[0][0])
     if metric_list[0][1] is not None:
         metric_kwargs = dict((index, metricConfig) for index, metricConfig in
@@ -103,12 +127,50 @@ def get_metric(metric_list):
         metric_kwargs = {}
     return metric_module, metric_kwargs

+
 class ResultAnalyser():
+    """
+    A shared result analysis tool for mono and multiview classifiers.
+    The main utility of this class is to generate a txt file summarizing
+    the results and possible interpretation for the classifier.
+    """

     def __init__(self, classifier, classification_indices, k_folds,
                  hps_method, metrics_list, n_iter, class_label_names,
                  train_pred, test_pred, directory, labels, database_name,
                  nb_cores, duration):
+        """
+
+        Parameters
+        ----------
+        classifier: estimator used for classification
+
+        classification_indices: list of indices for the train and test sets
+
+        k_folds: the sklearn StratifiedKFold object
+
+        hps_method: string naming the hyper-parameter search method
+
+        metrics_list: list of the metrics to compute on the results
+
+        n_iter: number of HPS iterations
+
+        class_label_names: list of the names of the labels
+
+        train_pred: classifier's prediction on the training set
+
+        test_pred: classifier's prediction on the testing set
+
+        directory: directory where to save the result analysis
+
+        labels: the full labels array (Y in sklearn)
+
+        database_name: the name of the database
+
+        nb_cores: number of cores/threads used for the classification
+
+        duration: duration of the classification
+        """
         self.classifier = classifier
         self.train_indices, self.test_indices = classification_indices
         self.k_folds = k_folds
@@ -127,23 +189,29 @@ class ResultAnalyser():
         self.metric_scores = {}

     def get_all_metrics_scores(self, ):
+        """
+        Get the scores for all the metrics in the list
+        Returns
+        -------
+        """
         for metric, metric_args in self.metrics_list:
             self.metric_scores[metric] = self.get_metric_scores(metric,
                                                                  metric_args)

     def get_metric_scores(self, metric, metric_kwargs):
         """
+        Get the train and test scores for a specific metric and its arguments

         Parameters
         ----------

-        metric :
+        metric : name of the metric, must be implemented in metrics

-        metric_kwargs :
+        metric_kwargs : the dictionary containing the arguments for the metric.

         Returns
         -------
-        list of [train_score, test_score]
+        train_score, test_score
         """
         metric_module = getattr(metrics, metric)
         train_score = metric_module.score(y_true=self.labels[self.train_indices],
@@ -156,17 +224,17 @@ class ResultAnalyser():

     def print_metric_score(self,):
         """
-        this function print the metrics scores
+        Generates a string, formatting the metrics configuration and scores

         Parameters
         ----------
-        metric_scores : the score of metrics
+        metric_scores : dictionary of train_score, test_score for each metric

         metric_list : list of metrics

         Returns
         -------
-        metric_score_string string containing all metric results
+        metric_score_string string formatting all metric results
         """
         metric_score_string = "\n\n"
         for metric, metric_kwargs in self.metrics_list:
@@ -188,13 +256,14 @@ class ResultAnalyser():

     def get_db_config_string(self,):
         """
+        Generates a string, formatting all the information on the database

         Parameters
         ----------

         Returns
         -------
-
+        db_config_string string, formatting all the information on the database
         """
         learning_ratio = len(self.train_indices) / (
                 len(self.train_indices) + len(self.test_indices))
@@ -208,6 +277,13 @@ class ResultAnalyser():
         return db_config_string

     def get_classifier_config_string(self, ):
+        """
+        Formats the information about the classifier and its configuration
+
+        Returns
+        -------
+        A string explaining the classifier's configuration
+        """
         classifier_config_string = "Classifier configuration : \n"
         classifier_config_string += "\t- " + self.classifier.get_config()+ "\n"
         classifier_config_string += "\t- Executed on {} core(s) \n".format(
@@ -218,6 +294,17 @@ class ResultAnalyser():
         return classifier_config_string

     def analyze(self, ):
+        """
+        Main function used in the monoview and multiview classification scripts
+
+        Returns
+        -------
+        string_analysis : a string that will be stored in the log and in a txt
+        file
+        image_analysis : a list of images to save
+        metric_scores : a dictionary of {metric: (train_score, test_score)}
+        used in later analysis.
+        """
         string_analysis = self.get_base_string()
         string_analysis += self.get_db_config_string()
         string_analysis += self.get_classifier_config_string()
diff --git a/multiview_platform/tests/test_utils/test_base.py b/multiview_platform/tests/test_utils/test_base.py
new file mode 100644
index 00000000..c4cd5998
--- /dev/null
+++ b/multiview_platform/tests/test_utils/test_base.py
@@ -0,0 +1,10 @@
+# import os
+# import unittest
+# import yaml
+# import numpy as np
+#
+# from multiview_platform.tests.utils import rm_tmp, tmp_path
+# from multiview_platform.mono_multi_view_classifiers.utils import base
+#
+#
+# class Test_ResultAnalyzer(unittest.TestCase):
--
GitLab
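
The test module created by this patch contains only commented-out stubs. As a rough sketch of where the stub could go, here is one possible test for ResultAnalyser.get_metric_scores, written outside the patch itself. It is not part of the commit: the "accuracy_score" metric module name, the use of DecisionTreeClassifier as a stand-in estimator, the "randomized_search" string, and the direct instantiation of ResultAnalyser are all assumptions made for illustration only.

# Illustrative sketch only, not part of the patch above.
# Assumptions: the metrics package exposes an "accuracy_score" module with a
# score(y_true, y_pred, **kwargs) function, and ResultAnalyser can be
# instantiated directly for testing.
import unittest

import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

from multiview_platform.mono_multi_view_classifiers.utils import base


class Test_ResultAnalyzer(unittest.TestCase):

    def setUp(self):
        # Small random binary problem: 10 train samples, 10 test samples
        rs = np.random.RandomState(42)
        self.labels = rs.randint(0, 2, 20)
        self.train_indices = np.arange(10)
        self.test_indices = np.arange(10, 20)
        self.train_pred = rs.randint(0, 2, 10)
        self.test_pred = rs.randint(0, 2, 10)
        self.analyzer = base.ResultAnalyser(
            classifier=DecisionTreeClassifier(),
            classification_indices=(self.train_indices, self.test_indices),
            k_folds=StratifiedKFold(n_splits=2),
            hps_method="randomized_search",
            metrics_list=[["accuracy_score", None]],  # assumed metric module name
            n_iter=5,
            class_label_names=["class_0", "class_1"],
            train_pred=self.train_pred,
            test_pred=self.test_pred,
            directory="tmp/",
            labels=self.labels,
            database_name="test_database",
            nb_cores=1,
            duration=1.0)

    def test_get_metric_scores(self):
        # The new docstring promises a train score and a test score
        train_score, test_score = self.analyzer.get_metric_scores(
            "accuracy_score", {})
        self.assertGreaterEqual(train_score, 0.0)
        self.assertLessEqual(train_score, 1.0)
        self.assertGreaterEqual(test_score, 0.0)
        self.assertLessEqual(test_score, 1.0)


if __name__ == "__main__":
    unittest.main()

The assertions deliberately check only the bounds of the two returned scores, which is what the new get_metric_scores docstring documents, so the sketch does not depend on the truncated body of that method.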