From 8743631b2e63393d2d8b86bbf9dbf2bdaf9f7b8e Mon Sep 17 00:00:00 2001 From: bbauvin <baptiste.bauvin@centrale-marseille.fr> Date: Fri, 3 Nov 2017 10:48:37 -0400 Subject: [PATCH] Full tests for metrics scores --- .travis.yml | 4 +- .../Metrics/f1_score.py | 2 +- Code/Tests/test_ExecClassif.py | 42 +++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 897b933b..ce8e60a1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,8 +10,8 @@ addons: - gfortran install: - pip install -U pip pip-tools -- pip install numpy scipy scikit-learn matplotlib logging joblib argparse h5py -- git clone https://github.com/aldro61/pyscm.git /tmp/pyscm && cd /tmp/pyscm/ && python setup.py install && cd ~/babau1/multiview-machine-learning-omis +- pip install numpy scipy scikit-learn matplotlib joblib argparse h5py +- cd .. && git clone https://github.com/aldro61/pyscm.git && cd pyscm/ && python setup.py install && cd ../multiview-machine-learning-omis script: - python -m unittest discover diff --git a/Code/MonoMultiViewClassifiers/Metrics/f1_score.py b/Code/MonoMultiViewClassifiers/Metrics/f1_score.py index ab33833e..467420b4 100644 --- a/Code/MonoMultiViewClassifiers/Metrics/f1_score.py +++ b/Code/MonoMultiViewClassifiers/Metrics/f1_score.py @@ -27,7 +27,7 @@ def score(y_true, y_pred, **kwargs): try: average = kwargs["3"] except: - average = "binary" + average = "micro" score = metric(y_true, y_pred, sample_weight=sample_weight, labels=labels, pos_label=pos_label, average=average) return score diff --git a/Code/Tests/test_ExecClassif.py b/Code/Tests/test_ExecClassif.py index 7ee3811f..8930a29d 100644 --- a/Code/Tests/test_ExecClassif.py +++ b/Code/Tests/test_ExecClassif.py @@ -205,6 +205,7 @@ class Test_genMetricsScores(unittest.TestCase): @classmethod def setUpClass(cls): cls.multiclass_labels = np.array([0,1,2,3,4,5,2,1,3]) + cls.wrong_labels = np.array([1,3,3,4,5,0,2,4,3]) cls.multiclassResults = [{"chicken_is_heaven": {"labels": cls.multiclass_labels}}] cls.true_labels = np.array([0,2,2,3,4,5,1,3,2]) @@ -215,6 +216,47 @@ class Test_genMetricsScores(unittest.TestCase): multiclassResults = ExecClassif.genMetricsScores(cls.multiclassResults, cls.true_labels, cls.metrics) cls.assertEqual(cls.score_to_get, multiclassResults[0]["chicken_is_heaven"]["metricsScores"]["accuracy_score"]) + def test_multiple_clf(cls): + cls.multiclassResults = [{"chicken_is_heaven": {"labels": cls.multiclass_labels}, + "cheese_is_no_disease": {"labels": cls.wrong_labels}}, + ] + multiclassResults = ExecClassif.genMetricsScores(cls.multiclassResults, cls.true_labels, cls.metrics) + cls.assertEqual(0, multiclassResults[0]["cheese_is_no_disease"]["metricsScores"]["accuracy_score"]) + cls.assertEqual(cls.score_to_get, multiclassResults[0]["chicken_is_heaven"]["metricsScores"]["accuracy_score"]) + + def test_multiple_metrics(cls): + from sklearn.metrics import f1_score + cls.score_to_get_f1 = f1_score(cls.true_labels, cls.multiclass_labels, average="micro") + cls.metrics = [["accuracy_score"], ["f1_score"]] + multiclassResults = ExecClassif.genMetricsScores(cls.multiclassResults, cls.true_labels, cls.metrics) + cls.assertEqual(cls.score_to_get, multiclassResults[0]["chicken_is_heaven"]["metricsScores"]["accuracy_score"]) + cls.assertEqual(cls.score_to_get_f1, multiclassResults[0]["chicken_is_heaven"]["metricsScores"]["f1_score"]) + + def test_multiple_iterations(cls): + cls.multiclassResults = [{"chicken_is_heaven": {"labels": cls.multiclass_labels}}, + {"chicken_is_heaven": {"labels": cls.wrong_labels}}, + ] + multiclassResults = ExecClassif.genMetricsScores(cls.multiclassResults, cls.true_labels, cls.metrics) + cls.assertEqual(0, multiclassResults[1]["chicken_is_heaven"]["metricsScores"]["accuracy_score"]) + cls.assertEqual(cls.score_to_get, multiclassResults[0]["chicken_is_heaven"]["metricsScores"]["accuracy_score"]) + + def test_all(cls): + cls.multiclassResults = [{"chicken_is_heaven": {"labels": cls.multiclass_labels}, + "cheese_is_no_disease": {"labels": cls.wrong_labels}}, + {"chicken_is_heaven": {"labels": cls.wrong_labels}, + "cheese_is_no_disease": {"labels": cls.multiclass_labels}}, + ] + cls.metrics = [["accuracy_score"], ["f1_score"]] + from sklearn.metrics import f1_score + cls.score_to_get_f1 = f1_score(cls.true_labels, cls.multiclass_labels, average="micro") + multiclassResults = ExecClassif.genMetricsScores(cls.multiclassResults, cls.true_labels, cls.metrics) + cls.assertEqual(0, multiclassResults[1]["chicken_is_heaven"]["metricsScores"]["accuracy_score"]) + cls.assertEqual(cls.score_to_get_f1, multiclassResults[1]["cheese_is_no_disease"]["metricsScores"]["f1_score"]) + + # {}, + # {"cheese_is_no_disease": {"labels": cls.multiclass_labels}}} +# {{{"chicken_is_heaven": {"labels": cls.wrong_labels}}, +# {}}} -- GitLab