Skip to content
Snippets Groups Projects
Commit 6ff3a234 authored by Baptiste Bauvin
Browse files

Added a base_estimator loader

parent b03bb8dc
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ from sklearn.tree import DecisionTreeClassifier
from .. import metrics
from ..monoview.monoview_utils import CustomRandint, BaseMonoviewClassifier, \
get_accuracy_graph
from ..utils.base import base_boosting_estimators
# Author-Info
__author__ = "Baptiste Bauvin"
......@@ -53,11 +54,11 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
"""
def __init__(self, random_state=None, n_estimators=50,
base_estimator=None, **kwargs):
base_estimator=None, base_estimator_config=None, **kwargs):
if isinstance(base_estimator, str):
if base_estimator == "DecisionTreeClassifier":
base_estimator = DecisionTreeClassifier()
base_estimator = BaseMonoviewClassifier.get_base_estimator(self,
base_estimator,
base_estimator_config)
AdaBoostClassifier.__init__(self,
random_state=random_state,
n_estimators=n_estimators,
......@@ -67,7 +68,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
self.param_names = ["n_estimators", "base_estimator"]
self.classed_params = ["base_estimator"]
self.distribs = [CustomRandint(low=1, high=500),
[DecisionTreeClassifier(max_depth=1)]]
base_boosting_estimators]
self.weird_strings = {"base_estimator": "class_name"}
self.plotted_metric = metrics.zero_one_loss
self.plotted_metric_name = "zero_one_loss"
......
......@@ -35,9 +35,10 @@ class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier):
init=init,
random_state=random_state
)
self.param_names = ["n_estimators", ]
self.param_names = ["n_estimators", "max_depth"]
self.classed_params = []
self.distribs = [CustomRandint(low=50, high=500), ]
self.distribs = [CustomRandint(low=50, high=500),
CustomRandint(low=1, high=10),]
self.weird_strings = {}
self.plotted_metric = metrics.zero_one_loss
self.plotted_metric_name = "zero_one_loss"
......
......@@ -61,7 +61,7 @@ class RandomForest(RandomForestClassifier, BaseMonoviewClassifier):
"random_state"]
self.classed_params = []
self.distribs = [CustomRandint(low=1, high=300),
CustomRandint(low=1, high=300),
CustomRandint(low=1, high=10),
["gini", "entropy"], [random_state]]
self.weird_strings = {}
......
......@@ -3,6 +3,9 @@ from sklearn.base import BaseEstimator
from abc import abstractmethod
from datetime import timedelta as hms
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
from multiview_platform.mono_multi_view_classifiers import metrics
......@@ -60,6 +63,24 @@ class BaseClassifier(BaseEstimator, ):
else:
return self.__class__.__name__ + "with no config."
def get_base_estimator(self, base_estimator, estimator_config):
    """
    Build (or configure) the base estimator used by an ensemble classifier.

    Parameters
    ----------
    base_estimator : None, str or sklearn.base.BaseEstimator
        ``None`` selects a ``DecisionTreeClassifier``. A string must be one
        of ``"DecisionTreeClassifier"``, ``"AdaboostClassifier"`` (or
        scikit-learn's own spelling ``"AdaBoostClassifier"``) or
        ``"RandomForestClassifier"``. An estimator instance is configured
        through its ``set_params`` method and whatever that call returns
        is handed back.
    estimator_config : dict or None
        Keyword arguments forwarded to the estimator constructor (or to
        ``set_params``). ``None`` is treated as an empty configuration.

    Returns
    -------
    The estimator built or configured from the given configuration.

    Raises
    ------
    ValueError
        If ``base_estimator`` is an unknown string, or is neither ``None``,
        a string nor a ``BaseEstimator`` instance.
    """
    # Callers may pass no configuration at all (e.g. a ``None`` default);
    # unpacking ``None`` with ``**`` would raise a TypeError.
    if estimator_config is None:
        estimator_config = {}

    if base_estimator is None:
        return DecisionTreeClassifier(**estimator_config)

    if isinstance(base_estimator, str):
        if base_estimator == "DecisionTreeClassifier":
            return DecisionTreeClassifier(**estimator_config)
        # Keep the historical "AdaboostClassifier" spelling working while
        # also accepting the actual scikit-learn class name.
        if base_estimator in ("AdaboostClassifier", "AdaBoostClassifier"):
            return AdaBoostClassifier(**estimator_config)
        if base_estimator == "RandomForestClassifier":
            return RandomForestClassifier(**estimator_config)
        raise ValueError('Base estimator string {} does not match an available classifier.'.format(base_estimator))

    if isinstance(base_estimator, BaseEstimator):
        return base_estimator.set_params(**estimator_config)

    raise ValueError('base_estimator must be either a string or a BaseEstimator child class, it is {}'.format(type(base_estimator)))
def to_str(self, param_name):
"""
Formats a parameter into a string
......@@ -317,3 +338,10 @@ class ResultAnalyser():
self.labels[self.test_indices])
image_analysis = {}
return string_analysis, image_analysis, self.metric_scores
# Candidate base estimators for boosting: decision trees whose max_depth
# ranges from 1 to 5 (inclusive).
base_boosting_estimators = [DecisionTreeClassifier(max_depth=depth)
                            for depth in range(1, 6)]
\ No newline at end of file
# import os
# import unittest
# import yaml
# import numpy as np
#
# from multiview_platform.tests.utils import rm_tmp, tmp_path
# from multiview_platform.mono_multi_view_classifiers.utils import base
#
#
# class Test_ResultAnalyzer(unittest.TestCase):
import os
import unittest
import yaml
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from multiview_platform.tests.utils import rm_tmp, tmp_path
from multiview_platform.mono_multi_view_classifiers.utils import base
class Test_ResultAnalyzer(unittest.TestCase):
    """Placeholder test case: it currently defines no test methods."""
class Test_BaseEstimator(unittest.TestCase):
    """Tests for ``BaseClassifier.get_base_estimator``."""

    @classmethod
    def setUpClass(cls):
        # Shared fixtures: an estimator name, a configuration that is valid
        # for a decision tree, and the object whose method is under test.
        cls.base_estimator = "DecisionTreeClassifier"
        cls.base_estimator_config = {"max_depth": 10,
                                     "splitter": "best"}
        cls.est = base.BaseClassifier()

    def _assert_configured_tree(self, estimator):
        """Check that *estimator* is a decision tree carrying the shared config."""
        self.assertIsInstance(estimator, DecisionTreeClassifier)
        self.assertEqual(estimator.max_depth, 10)
        self.assertEqual(estimator.splitter, "best")

    def test_simple(self):
        # A known estimator name is instantiated with the given config.
        base_estim = self.est.get_base_estimator(self.base_estimator,
                                                 self.base_estimator_config)
        self._assert_configured_tree(base_estim)

    def test_class(self):
        # An estimator instance is reconfigured: the config overrides the
        # parameters it was built with.
        base_estimator = DecisionTreeClassifier(max_depth=15, splitter="random")
        base_estim = self.est.get_base_estimator(base_estimator,
                                                 self.base_estimator_config)
        self._assert_configured_tree(base_estim)

    def test_none(self):
        # ``None`` falls back to a decision tree built from the config.
        base_estim = self.est.get_base_estimator(None,
                                                 self.base_estimator_config)
        self._assert_configured_tree(base_estim)

    def test_wrong_args(self):
        # ``n_estimators`` is not a DecisionTreeClassifier argument, so the
        # constructor call must fail.
        base_estimator_config = {"n_estimators": 10,
                                 "splitter": "best"}
        with self.assertRaises(TypeError):
            self.est.get_base_estimator(self.base_estimator,
                                        base_estimator_config)

    def test_unknown_name(self):
        # A string that names no supported classifier is rejected.
        with self.assertRaises(ValueError):
            self.est.get_base_estimator("NotAClassifier",
                                        self.base_estimator_config)

    def test_wrong_type(self):
        # Anything that is not None, a string or an estimator is rejected.
        with self.assertRaises(ValueError):
            self.est.get_base_estimator(42, self.base_estimator_config)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment