diff --git a/Code/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py b/Code/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py index 1fa0b546c8af374e91b0badd0b42115aed9bbb8c..2464619d82c6d933690624db24c0b2ac42f5c446 100644 --- a/Code/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py +++ b/Code/MonoMultiViewClassifiers/MonoviewClassifiers/Adaboost.py @@ -58,12 +58,14 @@ def randomizedSearch(X_train, y_train, randomState, outputFileName, KFolds=4, me param = {"classifier__n_estimators": randint(1, 150), "classifier__base_estimator": [DecisionTreeClassifier()]} + metricModule = getattr(Metrics, metric[0]) if metric[1] is not None: metricKWARGS = dict((index, metricConfig) for index, metricConfig in enumerate(metric[1])) else: metricKWARGS = {} scorer = metricModule.get_scorer(**metricKWARGS) + grid = RandomizedSearchCV(pipeline, n_iter=nIter, param_distributions=param, refit=True, n_jobs=nbCores, scoring=scorer, cv=KFolds, random_state=randomState) detector = grid.fit(X_train, y_train) @@ -79,7 +81,7 @@ def randomizedSearch(X_train, y_train, randomState, outputFileName, KFolds=4, me def getConfig(config): - if type(config) not in [list, dict]: # Used in late fusion when config is a classifier + if type(config) not in [list, dict]: # Used in late fusion when config is a classifier return "\n\t\t- Adaboost with num_esimators : " + str(config.n_estimators) + ", base_estimators : " + str( config.base_estimator) else: diff --git a/Code/Tests/Test_MonoviewClassifiers/test_Adaboost.py b/Code/Tests/Test_MonoviewClassifiers/test_Adaboost.py index 1f2e2ee79eead2dbf9895b5de4a3ae536085daf7..8385f0d8f0ec084fe597bd5c4f815bfb9c835307 100644 --- a/Code/Tests/Test_MonoviewClassifiers/test_Adaboost.py +++ b/Code/Tests/Test_MonoviewClassifiers/test_Adaboost.py @@ -1,9 +1,60 @@ import unittest import numpy as np +from sklearn.tree import DecisionTreeClassifier from ...MonoMultiViewClassifiers.MonoviewClassifiers import Adaboost +class Test_canProbas(unittest.TestCase): + + def test_simple(cls): + cls.assertTrue(Adaboost.canProbas()) + + +class Test_paramsToSet(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.n_iter = 4 + cls.random_state = np.random.RandomState(42) + + def test_simple(cls): + res = Adaboost.paramsToSet(cls.n_iter, cls.random_state) + cls.assertEqual(len(res), cls.n_iter) + cls.assertEqual(type(res[0][0]), int) + cls.assertEqual(type(res[0][1]), type(DecisionTreeClassifier())) + cls.assertEqual([7,4,13,11], [resIter[0] for resIter in res]) + + +class Test_getKWARGS(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.kwargs_list = [("CL_Adaboost_n_est", 10), + ("CL_Adaboost_b_est", DecisionTreeClassifier())] + + def test_simple(cls): + res = Adaboost.getKWARGS(cls.kwargs_list) + cls.assertIn("0", res) + cls.assertIn("1", res) + cls.assertEqual(type(res), dict) + cls.assertEqual(res["0"], 10) + # Can't test decision tree + + def test_wrong(cls): + cls.kwargs_list[0] = ("chicken_is_heaven",42) + with cls.assertRaises(ValueError) as catcher: + Adaboost.getKWARGS(cls.kwargs_list) + exception = catcher.exception + # cls.assertEqual(exception, "Wrong arguments served to Adaboost") + + +class Test_randomizedSearch(unittest.TestCase): + + def test_simple(cls): + pass # Test with simple params + + class Test_fit(unittest.TestCase): def setUp(self):