Skip to content
Snippets Groups Projects
Commit 6ff3a234 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added a base_estimator loader

parent b03bb8dc
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,7 @@ from sklearn.tree import DecisionTreeClassifier ...@@ -7,6 +7,7 @@ from sklearn.tree import DecisionTreeClassifier
from .. import metrics from .. import metrics
from ..monoview.monoview_utils import CustomRandint, BaseMonoviewClassifier, \ from ..monoview.monoview_utils import CustomRandint, BaseMonoviewClassifier, \
get_accuracy_graph get_accuracy_graph
from ..utils.base import base_boosting_estimators
# Author-Info # Author-Info
__author__ = "Baptiste Bauvin" __author__ = "Baptiste Bauvin"
...@@ -53,11 +54,11 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): ...@@ -53,11 +54,11 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
""" """
def __init__(self, random_state=None, n_estimators=50, def __init__(self, random_state=None, n_estimators=50,
base_estimator=None, **kwargs): base_estimator=None, base_estimator_config=None, **kwargs):
if isinstance(base_estimator, str): base_estimator = BaseMonoviewClassifier.get_base_estimator(self,
if base_estimator == "DecisionTreeClassifier": base_estimator,
base_estimator = DecisionTreeClassifier() base_estimator_config)
AdaBoostClassifier.__init__(self, AdaBoostClassifier.__init__(self,
random_state=random_state, random_state=random_state,
n_estimators=n_estimators, n_estimators=n_estimators,
...@@ -67,7 +68,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier): ...@@ -67,7 +68,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
self.param_names = ["n_estimators", "base_estimator"] self.param_names = ["n_estimators", "base_estimator"]
self.classed_params = ["base_estimator"] self.classed_params = ["base_estimator"]
self.distribs = [CustomRandint(low=1, high=500), self.distribs = [CustomRandint(low=1, high=500),
[DecisionTreeClassifier(max_depth=1)]] base_boosting_estimators]
self.weird_strings = {"base_estimator": "class_name"} self.weird_strings = {"base_estimator": "class_name"}
self.plotted_metric = metrics.zero_one_loss self.plotted_metric = metrics.zero_one_loss
self.plotted_metric_name = "zero_one_loss" self.plotted_metric_name = "zero_one_loss"
......
...@@ -35,9 +35,10 @@ class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier): ...@@ -35,9 +35,10 @@ class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier):
init=init, init=init,
random_state=random_state random_state=random_state
) )
self.param_names = ["n_estimators", ] self.param_names = ["n_estimators", "max_depth"]
self.classed_params = [] self.classed_params = []
self.distribs = [CustomRandint(low=50, high=500), ] self.distribs = [CustomRandint(low=50, high=500),
CustomRandint(low=1, high=10),]
self.weird_strings = {} self.weird_strings = {}
self.plotted_metric = metrics.zero_one_loss self.plotted_metric = metrics.zero_one_loss
self.plotted_metric_name = "zero_one_loss" self.plotted_metric_name = "zero_one_loss"
......
...@@ -61,7 +61,7 @@ class RandomForest(RandomForestClassifier, BaseMonoviewClassifier): ...@@ -61,7 +61,7 @@ class RandomForest(RandomForestClassifier, BaseMonoviewClassifier):
"random_state"] "random_state"]
self.classed_params = [] self.classed_params = []
self.distribs = [CustomRandint(low=1, high=300), self.distribs = [CustomRandint(low=1, high=300),
CustomRandint(low=1, high=300), CustomRandint(low=1, high=10),
["gini", "entropy"], [random_state]] ["gini", "entropy"], [random_state]]
self.weird_strings = {} self.weird_strings = {}
......
...@@ -3,6 +3,9 @@ from sklearn.base import BaseEstimator ...@@ -3,6 +3,9 @@ from sklearn.base import BaseEstimator
from abc import abstractmethod from abc import abstractmethod
from datetime import timedelta as hms from datetime import timedelta as hms
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from multiview_platform.mono_multi_view_classifiers import metrics from multiview_platform.mono_multi_view_classifiers import metrics
...@@ -60,6 +63,24 @@ class BaseClassifier(BaseEstimator, ): ...@@ -60,6 +63,24 @@ class BaseClassifier(BaseEstimator, ):
else: else:
return self.__class__.__name__ + "with no config." return self.__class__.__name__ + "with no config."
def get_base_estimator(self, base_estimator, estimator_config):
if base_estimator is None:
return DecisionTreeClassifier(**estimator_config)
if isinstance(base_estimator, str):
if base_estimator == "DecisionTreeClassifier":
return DecisionTreeClassifier(**estimator_config)
elif base_estimator == "AdaboostClassifier":
return AdaBoostClassifier(**estimator_config)
elif base_estimator == "RandomForestClassifier":
return RandomForestClassifier(**estimator_config)
else:
raise ValueError('Base estimator string {} does not match an available classifier.'.format(base_estimator))
elif isinstance(base_estimator, BaseEstimator):
return base_estimator.set_params(**estimator_config)
else:
raise ValueError('base_estimator must be either a string or a BaseEstimator child class, it is {}'.format(type(base_estimator)))
def to_str(self, param_name): def to_str(self, param_name):
""" """
Formats a parameter into a string Formats a parameter into a string
...@@ -317,3 +338,10 @@ class ResultAnalyser(): ...@@ -317,3 +338,10 @@ class ResultAnalyser():
self.labels[self.test_indices]) self.labels[self.test_indices])
image_analysis = {} image_analysis = {}
return string_analysis, image_analysis, self.metric_scores return string_analysis, image_analysis, self.metric_scores
base_boosting_estimators = [DecisionTreeClassifier(max_depth=1),
DecisionTreeClassifier(max_depth=2),
DecisionTreeClassifier(max_depth=3),
DecisionTreeClassifier(max_depth=4),
DecisionTreeClassifier(max_depth=5), ]
\ No newline at end of file
# import os import os
# import unittest import unittest
# import yaml import yaml
# import numpy as np import numpy as np
# from sklearn.tree import DecisionTreeClassifier
# from multiview_platform.tests.utils import rm_tmp, tmp_path
# from multiview_platform.mono_multi_view_classifiers.utils import base from multiview_platform.tests.utils import rm_tmp, tmp_path
# from multiview_platform.mono_multi_view_classifiers.utils import base
#
# class Test_ResultAnalyzer(unittest.TestCase):
class Test_ResultAnalyzer(unittest.TestCase):
pass
class Test_BaseEstimator(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.base_estimator = "DecisionTreeClassifier"
cls.base_estimator_config = {"max_depth":10,
"splitter": "best"}
cls.est = base.BaseClassifier()
def test_simple(self):
base_estim = self.est.get_base_estimator(self.base_estimator,
self.base_estimator_config)
self.assertTrue(isinstance(base_estim, DecisionTreeClassifier))
self.assertEqual(base_estim.max_depth, 10)
self.assertEqual(base_estim.splitter, "best")
def test_class(self):
base_estimator = DecisionTreeClassifier(max_depth=15, splitter="random")
base_estim = self.est.get_base_estimator(base_estimator,
self.base_estimator_config)
self.assertTrue(isinstance(base_estim, DecisionTreeClassifier))
self.assertEqual(base_estim.max_depth, 10)
self.assertEqual(base_estim.splitter, "best")
def test_wrong_args(self):
base_estimator_config = {"n_estimators": 10,
"splitter": "best"}
with self.assertRaises(TypeError):
base_estim = self.est.get_base_estimator(self.base_estimator,
base_estimator_config)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment