Skip to content
Snippets Groups Projects
Commit f67d5e07 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added tests for most of adaboost

parent ece0e18a
Branches
Tags
No related merge requests found
......@@ -58,12 +58,14 @@ def randomizedSearch(X_train, y_train, randomState, outputFileName, KFolds=4, me
param = {"classifier__n_estimators": randint(1, 150),
"classifier__base_estimator": [DecisionTreeClassifier()]}
metricModule = getattr(Metrics, metric[0])
if metric[1] is not None:
metricKWARGS = dict((index, metricConfig) for index, metricConfig in enumerate(metric[1]))
else:
metricKWARGS = {}
scorer = metricModule.get_scorer(**metricKWARGS)
grid = RandomizedSearchCV(pipeline, n_iter=nIter, param_distributions=param, refit=True, n_jobs=nbCores,
scoring=scorer, cv=KFolds, random_state=randomState)
detector = grid.fit(X_train, y_train)
......
import unittest
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from ...MonoMultiViewClassifiers.MonoviewClassifiers import Adaboost
class Test_canProbas(unittest.TestCase):
def test_simple(cls):
cls.assertTrue(Adaboost.canProbas())
class Test_paramsToSet(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.n_iter = 4
cls.random_state = np.random.RandomState(42)
def test_simple(cls):
res = Adaboost.paramsToSet(cls.n_iter, cls.random_state)
cls.assertEqual(len(res), cls.n_iter)
cls.assertEqual(type(res[0][0]), int)
cls.assertEqual(type(res[0][1]), type(DecisionTreeClassifier()))
cls.assertEqual([7,4,13,11], [resIter[0] for resIter in res])
class Test_getKWARGS(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.kwargs_list = [("CL_Adaboost_n_est", 10),
("CL_Adaboost_b_est", DecisionTreeClassifier())]
def test_simple(cls):
res = Adaboost.getKWARGS(cls.kwargs_list)
cls.assertIn("0", res)
cls.assertIn("1", res)
cls.assertEqual(type(res), dict)
cls.assertEqual(res["0"], 10)
# Can't test decision tree
def test_wrong(cls):
cls.kwargs_list[0] = ("chicken_is_heaven",42)
with cls.assertRaises(ValueError) as catcher:
Adaboost.getKWARGS(cls.kwargs_list)
exception = catcher.exception
# cls.assertEqual(exception, "Wrong arguments served to Adaboost")
class Test_randomizedSearch(unittest.TestCase):
def test_simple(cls):
pass # Test with simple params
class Test_fit(unittest.TestCase):
def setUp(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment