diff --git a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/adaboost.py b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/adaboost.py index e1f23cc46dffc9c5aeac01b8b83066b3198651ca..9d36b968ad0a93ba7dd881a51192a928ef5b7e04 100644 --- a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/adaboost.py +++ b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/adaboost.py @@ -7,6 +7,7 @@ from sklearn.tree import DecisionTreeClassifier from .. import metrics from ..monoview.monoview_utils import CustomRandint, BaseMonoviewClassifier, \ get_accuracy_graph +from ..utils.base import base_boosting_estimators # Author-Info __author__ = "Baptiste Bauvin" @@ -53,11 +54,11 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): """ def __init__(self, random_state=None, n_estimators=50, - base_estimator=None, **kwargs): + base_estimator=None, base_estimator_config=None, **kwargs): - if isinstance(base_estimator, str): - if base_estimator == "DecisionTreeClassifier": - base_estimator = DecisionTreeClassifier() + base_estimator = BaseMonoviewClassifier.get_base_estimator(self, + base_estimator, + base_estimator_config) AdaBoostClassifier.__init__(self, random_state=random_state, n_estimators=n_estimators, @@ -67,7 +68,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): self.param_names = ["n_estimators", "base_estimator"] self.classed_params = ["base_estimator"] self.distribs = [CustomRandint(low=1, high=500), - [DecisionTreeClassifier(max_depth=1)]] + base_boosting_estimators] self.weird_strings = {"base_estimator": "class_name"} self.plotted_metric = metrics.zero_one_loss self.plotted_metric_name = "zero_one_loss" diff --git a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/gradient_boosting.py b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/gradient_boosting.py index bf8cccb2f63c5a3372fe642cd9cc508e84efea23..4a3cae43e94f06ef4bf82d4ec4eeebfe3fbddb36 100644 --- a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/gradient_boosting.py +++ b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/gradient_boosting.py @@ -35,9 +35,10 @@ class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier): init=init, random_state=random_state ) - self.param_names = ["n_estimators", ] + self.param_names = ["n_estimators", "max_depth"] self.classed_params = [] - self.distribs = [CustomRandint(low=50, high=500), ] + self.distribs = [CustomRandint(low=50, high=500), + CustomRandint(low=1, high=10),] self.weird_strings = {} self.plotted_metric = metrics.zero_one_loss self.plotted_metric_name = "zero_one_loss" diff --git a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/random_forest.py b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/random_forest.py index 82a442d99c42ac96604ac36e4469fd1288eb0b6f..06bb25af3a61617e529ab43bb745165b7303e877 100644 --- a/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/random_forest.py +++ b/multiview_platform/mono_multi_view_classifiers/monoview_classifiers/random_forest.py @@ -61,7 +61,7 @@ class RandomForest(RandomForestClassifier, BaseMonoviewClassifier): "random_state"] self.classed_params = [] self.distribs = [CustomRandint(low=1, high=300), - CustomRandint(low=1, high=300), + CustomRandint(low=1, high=10), ["gini", "entropy"], [random_state]] self.weird_strings = {} diff --git a/multiview_platform/mono_multi_view_classifiers/utils/base.py b/multiview_platform/mono_multi_view_classifiers/utils/base.py index 6c2dd88efa3768f414aeeb0ae4bb4a9f124444f6..82e5cc236f569c189a42fb303d6e4e3cf24e09ef 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/base.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/base.py @@ -3,6 +3,9 @@ from sklearn.base import BaseEstimator from abc import abstractmethod from datetime import timedelta as hms +from sklearn.tree import DecisionTreeClassifier +from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier + from multiview_platform.mono_multi_view_classifiers import metrics @@ -60,6 +63,24 @@ class BaseClassifier(BaseEstimator, ): else: return self.__class__.__name__ + "with no config." + def get_base_estimator(self, base_estimator, estimator_config): + if base_estimator is None: + return DecisionTreeClassifier(**estimator_config) + if isinstance(base_estimator, str): + if base_estimator == "DecisionTreeClassifier": + return DecisionTreeClassifier(**estimator_config) + elif base_estimator == "AdaboostClassifier": + return AdaBoostClassifier(**estimator_config) + elif base_estimator == "RandomForestClassifier": + return RandomForestClassifier(**estimator_config) + else: + raise ValueError('Base estimator string {} does not match an available classifier.'.format(base_estimator)) + elif isinstance(base_estimator, BaseEstimator): + return base_estimator.set_params(**estimator_config) + else: + raise ValueError('base_estimator must be either a string or a BaseEstimator child class, it is {}'.format(type(base_estimator))) + + def to_str(self, param_name): """ Formats a parameter into a string @@ -317,3 +338,10 @@ class ResultAnalyser(): self.labels[self.test_indices]) image_analysis = {} return string_analysis, image_analysis, self.metric_scores + + +base_boosting_estimators = [DecisionTreeClassifier(max_depth=1), + DecisionTreeClassifier(max_depth=2), + DecisionTreeClassifier(max_depth=3), + DecisionTreeClassifier(max_depth=4), + DecisionTreeClassifier(max_depth=5), ] \ No newline at end of file diff --git a/multiview_platform/tests/test_utils/test_base.py b/multiview_platform/tests/test_utils/test_base.py index c4cd5998232bb92146053abf29cc614892d89a7a..f4f87316308f9fe766341f3068f156c7a47b0117 100644 --- a/multiview_platform/tests/test_utils/test_base.py +++ b/multiview_platform/tests/test_utils/test_base.py @@ -1,10 +1,43 @@ -# import os -# import unittest -# import yaml -# import numpy as np -# -# from multiview_platform.tests.utils import rm_tmp, tmp_path -# from multiview_platform.mono_multi_view_classifiers.utils import base -# -# -# class Test_ResultAnalyzer(unittest.TestCase): +import os +import unittest +import yaml +import numpy as np +from sklearn.tree import DecisionTreeClassifier + +from multiview_platform.tests.utils import rm_tmp, tmp_path +from multiview_platform.mono_multi_view_classifiers.utils import base + + +class Test_ResultAnalyzer(unittest.TestCase): + pass + +class Test_BaseEstimator(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.base_estimator = "DecisionTreeClassifier" + cls.base_estimator_config = {"max_depth":10, + "splitter": "best"} + cls.est = base.BaseClassifier() + + def test_simple(self): + base_estim = self.est.get_base_estimator(self.base_estimator, + self.base_estimator_config) + self.assertTrue(isinstance(base_estim, DecisionTreeClassifier)) + self.assertEqual(base_estim.max_depth, 10) + self.assertEqual(base_estim.splitter, "best") + + def test_class(self): + base_estimator = DecisionTreeClassifier(max_depth=15, splitter="random") + base_estim = self.est.get_base_estimator(base_estimator, + self.base_estimator_config) + self.assertTrue(isinstance(base_estim, DecisionTreeClassifier)) + self.assertEqual(base_estim.max_depth, 10) + self.assertEqual(base_estim.splitter, "best") + + def test_wrong_args(self): + base_estimator_config = {"n_estimators": 10, + "splitter": "best"} + with self.assertRaises(TypeError): + base_estim = self.est.get_base_estimator(self.base_estimator, + base_estimator_config) \ No newline at end of file