Commit b03bb8dc authored by Baptiste Bauvin

Starting tests

parent 3cb2a6fe
@@ -4,11 +4,11 @@
     <content url="file://$MODULE_DIR$" />
     <orderEntry type="inheritedJdk" />
     <orderEntry type="sourceFolder" forTests="false" />
-    <orderEntry type="module" module-name="multiviewmetriclearning" />
     <orderEntry type="module" module-name="multiview_generator" />
     <orderEntry type="module" module-name="short_projects" />
     <orderEntry type="library" name="R User Library" level="project" />
     <orderEntry type="library" name="R Skeletons" level="application" />
     <orderEntry type="module" module-name="Datasets" />
+    <orderEntry type="module" module-name="scikit-multimodallearn" />
   </component>
 </module>
\ No newline at end of file
@@ -5,6 +5,7 @@ from datetime import timedelta as hms
 from multiview_platform.mono_multi_view_classifiers import metrics
 
 
 class BaseClassifier(BaseEstimator, ):
 
     def gen_best_params(self, detector):
@@ -42,17 +43,27 @@ class BaseClassifier(BaseEstimator, ):
             zip(self.param_names, self.distribs))
 
     def params_to_string(self):
+        """
+        Formats the parameters of the classifier as a string.
+        """
         return ", ".join(
             [param_name + " : " + self.to_str(param_name) for param_name in
              self.param_names])
 
     def get_config(self):
+        """
+        Generates a string containing all the information about the
+        classifier's configuration.
+        """
         if self.param_names:
             return self.__class__.__name__ + " with " + self.params_to_string()
         else:
             return self.__class__.__name__ + " with no config."
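For illustration, a minimal sketch of the strings these two methods produce; the FakeClassifier subclass and its parameter values are hypothetical, not part of this commit:

    from multiview_platform.mono_multi_view_classifiers.utils.base import BaseClassifier

    class FakeClassifier(BaseClassifier):
        def __init__(self, max_depth=3, criterion="gini"):
            self.max_depth = max_depth
            self.criterion = criterion
            self.param_names = ["max_depth", "criterion"]
            self.weird_strings = {}  # no special formatting for any parameter

    print(FakeClassifier().get_config())
    # -> "FakeClassifier with max_depth : 3, criterion : gini"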
     def to_str(self, param_name):
+        """
+        Formats a parameter into a string.
+        """
         if param_name in self.weird_strings:
             if self.weird_strings[param_name] == "class_name":
                 return self.get_params()[param_name].__class__.__name__
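The weird_strings mapping lets a parameter opt into special formatting; below is a minimal sketch of the "class_name" branch shown above, using a hypothetical wrapper classifier:

    from sklearn.tree import DecisionTreeClassifier
    from multiview_platform.mono_multi_view_classifiers.utils.base import BaseClassifier

    class WrapperClassifier(BaseClassifier):
        def __init__(self, base_estimator=None):
            self.base_estimator = base_estimator
            self.param_names = ["base_estimator"]
            # "class_name" makes to_str() return the value's class name
            # rather than its default str() representation
            self.weird_strings = {"base_estimator": "class_name"}

    clf = WrapperClassifier(base_estimator=DecisionTreeClassifier())
    print(clf.to_str("base_estimator"))  # -> "DecisionTreeClassifier"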
@@ -63,13 +74,23 @@ class BaseClassifier(BaseEstimator, ):
         return str(self.get_params()[param_name])
 
     def get_interpretation(self, directory, y_test, multi_class=False):
+        """
+        Base method that returns an empty string if there is no
+        interpretation method in the classifier's module.
+        """
         return ""
 
     def accepts_multi_class(self, random_state, n_samples=10, dim=2,
                             n_classes=3):
+        """
+        Base method to test whether the classifier accepts a multiclass
+        task. It is highly recommended to override it with a simple method
+        that returns True or False in the classifier's module, as that will
+        speed up the benchmark.
+        """
         if int(n_samples / n_classes) < 1:
             raise ValueError(
-                "n_samples ({}) / n_classe ({}) must be over 1".format(
+                "n_samples ({}) / n_class ({}) must be over 1".format(
                     n_samples,
                     n_classes))
         if hasattr(self, "accepts_mutli_class"):
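As the docstring recommends, a classifier module can short-circuit this generic probe by overriding the method (note that the hasattr check above tests the attribute name "accepts_mutli_class" exactly as spelled in the source). A minimal sketch with a hypothetical classifier:

    from multiview_platform.mono_multi_view_classifiers.utils.base import BaseClassifier

    class MyMonoviewClassifier(BaseClassifier):

        def accepts_multi_class(self, random_state, n_samples=10, dim=2,
                                n_classes=3):
            # This classifier handles multiclass natively, so skip the
            # fit-on-generated-data test and answer directly.
            return True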
@@ -95,6 +116,9 @@ def get_names(classed_list):
 
 def get_metric(metric_list):
+    """
+    Fetches the metric module in the metrics package.
+    """
     metric_module = getattr(metrics, metric_list[0][0])
     if metric_list[0][1] is not None:
         metric_kwargs = dict((index, metricConfig) for index, metricConfig in
@@ -103,12 +127,50 @@ def get_metric(metric_list):
         metric_kwargs = {}
     return metric_module, metric_kwargs
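A sketch of the metric_list structure get_metric() expects, inferred from the indexing above: a list of [metric_name, metric_kwargs] pairs, where names must match modules of the metrics package and kwargs may be None. The values here are illustrative:

    from multiview_platform.mono_multi_view_classifiers.utils.base import get_metric

    metric_list = [["accuracy_score", None], ["f1_score", None]]
    metric_module, metric_kwargs = get_metric(metric_list)
    # metric_module is now the metrics.accuracy_score module and, since the
    # kwargs entry of the first metric is None, metric_kwargs == {}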
 
 class ResultAnalyser():
+    """
+    A shared result analysis tool for mono- and multiview classifiers.
+    The main utility of this class is to generate a txt file summarizing
+    the results and a possible interpretation for the classifier.
+    """
 
     def __init__(self, classifier, classification_indices, k_folds,
                  hps_method, metrics_list, n_iter, class_label_names,
                  train_pred, test_pred, directory, labels, database_name,
                  nb_cores, duration):
+        """
+        Parameters
+        ----------
+        classifier : estimator used for classification
+        classification_indices : list of indices for the train and test sets
+        k_folds : the sklearn StratifiedKFold object
+        hps_method : string naming the hyper-parameter search method
+        metrics_list : list of the metrics to compute on the results
+        n_iter : number of HPS iterations
+        class_label_names : list of the names of the labels
+        train_pred : classifier's predictions on the training set
+        test_pred : classifier's predictions on the testing set
+        directory : directory where to save the result analysis
+        labels : the full labels array (y in sklearn)
+        database_name : the name of the database
+        nb_cores : number of cores/threads used for the classification
+        duration : duration of the classification
+        """
         self.classifier = classifier
         self.train_indices, self.test_indices = classification_indices
         self.k_folds = k_folds
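A hedged sketch of constructing a ResultAnalyser from the parameters documented above; every value is illustrative, and the hps_method string is an assumption, not taken from this commit:

    import numpy as np
    from sklearn.model_selection import StratifiedKFold
    from multiview_platform.mono_multi_view_classifiers.utils.base import ResultAnalyser

    labels = np.array([0, 1] * 50)                    # full y array
    train_indices, test_indices = np.arange(80), np.arange(80, 100)

    analyser = ResultAnalyser(
        classifier=FakeClassifier(),                  # sketched earlier
        classification_indices=(train_indices, test_indices),
        k_folds=StratifiedKFold(n_splits=5),
        hps_method="randomized_search",               # assumed name
        metrics_list=[["accuracy_score", None]],
        n_iter=10,
        class_label_names=["negative", "positive"],
        train_pred=labels[train_indices],             # perfect predictions,
        test_pred=labels[test_indices],               # for the sake of example
        directory="results/",
        labels=labels,
        database_name="demo_db",
        nb_cores=1,
        duration=3.2)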
@@ -127,23 +189,29 @@ class ResultAnalyser():
         self.metric_scores = {}
 
     def get_all_metrics_scores(self, ):
+        """
+        Gets the scores for all the metrics in the list and stores them
+        in self.metric_scores.
+        """
         for metric, metric_args in self.metrics_list:
             self.metric_scores[metric] = self.get_metric_scores(metric,
                                                                 metric_args)
 
     def get_metric_scores(self, metric, metric_kwargs):
         """
+        Gets the train and test scores for a specific metric and its
+        arguments.
+
         Parameters
         ----------
-        metric :
+        metric : name of the metric, must be implemented in metrics
-        metric_kwargs :
+        metric_kwargs : the dictionary containing the arguments for the metric
 
         Returns
         -------
-        list of [train_score, test_score]
+        train_score, test_score
         """
         metric_module = getattr(metrics, metric)
         train_score = metric_module.score(y_true=self.labels[self.train_indices],
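Continuing the sketch above, and following the docstring, the method returns a train score and a test score:

    train_score, test_score = analyser.get_metric_scores("accuracy_score", {})
    # With the perfect predictions used in the sketch, both scores are 1.0;
    # get_all_metrics_scores() would store the pair in
    # analyser.metric_scores["accuracy_score"]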
@@ -156,17 +224,17 @@ class ResultAnalyser():
 
     def print_metric_score(self,):
         """
-        this function print the metrics scores
+        Generates a string formatting the metrics configuration and scores.
 
         Parameters
         ----------
-        metric_scores : the score of metrics
+        metric_scores : dictionary of train_score, test_score for each metric
 
         metric_list : list of metrics
 
         Returns
         -------
-        metric_score_string string containing all metric results
+        metric_score_string : string formatting all metric results
         """
         metric_score_string = "\n\n"
         for metric, metric_kwargs in self.metrics_list:
@@ -188,13 +256,14 @@ class ResultAnalyser():
 
     def get_db_config_string(self,):
         """
+        Generates a string formatting all the information on the database.
 
         Parameters
         ----------
 
         Returns
         -------
+        db_config_string : string formatting all the information on the database
         """
         learning_ratio = len(self.train_indices) / (
                 len(self.train_indices) + len(self.test_indices))
@@ -208,6 +277,13 @@ class ResultAnalyser():
         return db_config_string
 
     def get_classifier_config_string(self, ):
+        """
+        Formats the information about the classifier and its configuration.
+
+        Returns
+        -------
+        A string explaining the classifier's configuration.
+        """
         classifier_config_string = "Classifier configuration : \n"
         classifier_config_string += "\t- " + self.classifier.get_config() + "\n"
         classifier_config_string += "\t- Executed on {} core(s) \n".format(
@@ -218,6 +294,17 @@ class ResultAnalyser():
         return classifier_config_string
 
     def analyze(self, ):
+        """
+        Main method, used in the monoview and multiview classification
+        scripts.
+
+        Returns
+        -------
+        string_analysis : a string that will be stored in the log and in
+        a txt file
+        image_analysis : a list of images to save
+        metric_scores : a dictionary of {metric: (train_score, test_score)}
+        used in later analysis
+        """
         string_analysis = self.get_base_string()
         string_analysis += self.get_db_config_string()
         string_analysis += self.get_classifier_config_string()
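A sketch of consuming the returned triple in a monoview script, following the Returns section above; the output path is illustrative:

    string_analysis, image_analysis, metric_scores = analyser.analyze()
    with open("results/analysis.txt", "w") as report:
        report.write(string_analysis)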
...

A commented-out test skeleton (every line added):

+# import os
+# import unittest
+# import yaml
+# import numpy as np
+#
+# from multiview_platform.tests.utils import rm_tmp, tmp_path
+# from multiview_platform.mono_multi_view_classifiers.utils import base
+#
+#
+# class Test_ResultAnalyzer(unittest.TestCase):
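One way the commented-out skeleton might be fleshed out once the tests are started; the test body below is a hypothetical sketch, not part of this commit:

    import unittest
    import numpy as np

    from multiview_platform.mono_multi_view_classifiers.utils import base


    class Test_ResultAnalyzer(unittest.TestCase):

        def test_get_metric_scores(self):
            # Perfect predictions on a tiny dataset should score 1.0 on
            # both the train and the test split.
            labels = np.array([0, 1] * 10)
            train, test = np.arange(10), np.arange(10, 20)
            analyser = base.ResultAnalyser(
                classifier=None, classification_indices=(train, test),
                k_folds=None, hps_method="", n_iter=1,
                metrics_list=[["accuracy_score", None]],
                class_label_names=["a", "b"],
                train_pred=labels[train], test_pred=labels[test],
                directory=".", labels=labels, database_name="test_db",
                nb_cores=1, duration=0.1)
            train_score, test_score = analyser.get_metric_scores(
                "accuracy_score", {})
            self.assertEqual(train_score, 1.0)
            self.assertEqual(test_score, 1.0)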