Commit 2e788d38 authored by bbauvin

Wrote some tests

parent de646719
@@ -47,6 +47,7 @@ def initConstants(args, X, classificationIndices, labelsNames, name, directory):
        raise
    return kwargs, t_start, feat, CL_type, X, learningRate, labelsString, outputFileName


def initTrainTest(X, Y, classificationIndices):
    trainIndices, testIndices = classificationIndices
    X_train = extractSubset(X, trainIndices)
@@ -55,19 +56,21 @@ def initTrainTest(X, Y, classificationIndices):
    y_test = Y[testIndices]
    return X_train, y_train, X_test, y_test


def getKWARGS(classifierModule, hyperParamSearch, nIter, CL_type, X_train, y_train, randomState,
              outputFileName, KFolds, nbCores, metrics, kwargs):
    if hyperParamSearch != "None":
+        logging.debug("Start:\t " + hyperParamSearch + " best settings with " + str(nIter) + " iterations for " + CL_type)
        classifierHPSearch = getattr(classifierModule, hyperParamSearch)
-        logging.debug("Start:\t RandomSearch best settings with " + str(nIter) + " iterations for " + CL_type)
        cl_desc = classifierHPSearch(X_train, y_train, randomState, outputFileName, KFolds=KFolds, nbCores=nbCores,
                                     metric=metrics[0], nIter=nIter)
        clKWARGS = dict((str(index), desc) for index, desc in enumerate(cl_desc))
-        logging.debug("Done:\t RandomSearch best settings")
+        logging.debug("Done:\t " + hyperParamSearch + " best settings")
    else:
        clKWARGS = kwargs[CL_type + "KWARGS"]
    return clKWARGS


def saveResults(stringAnalysis, outputFileName, full_labels_pred, y_train_pred, y_train, imagesAnalysis):
    logging.info(stringAnalysis)
    outputTextFile = open(outputFileName + '.txt', 'w')
...
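Note on the getKWARGS change above: the log messages now report whichever hyper-parameter search was requested instead of hard-coding "RandomSearch", and the search function itself is looked up on the classifier module by name with getattr. A minimal, self-contained sketch of that dispatch pattern; the module and search-function names below are made up for illustration and are not the project's real API:

import logging
import types

# Hypothetical classifier module exposing a search function under the requested name.
classifierModule = types.SimpleNamespace(
    randomizedSearch=lambda X_train, y_train, nIter: [50, "gini"]  # dummy description list
)

hyperParamSearch = "randomizedSearch"
nIter, CL_type = 5, "DecisionTree"

logging.debug("Start:\t " + hyperParamSearch + " best settings with "
              + str(nIter) + " iterations for " + CL_type)
searchFunction = getattr(classifierModule, hyperParamSearch)  # dispatch by name
cl_desc = searchFunction(None, None, nIter)
clKWARGS = dict((str(index), desc) for index, desc in enumerate(cl_desc))
logging.debug("Done:\t " + hyperParamSearch + " best settings")
print(clKWARGS)  # {'0': 50, '1': 'gini'}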
@@ -53,21 +53,27 @@ class Test_initConstants(unittest.TestCase):
        os.rmdir("Code/Tests/temp_tests/test_dir")
        os.rmdir("Code/Tests/temp_tests")


class Test_initTrainTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.random_state = np.random.RandomState(42)
        cls.X = cls.random_state.randint(0,500,(10,5))
-        print(cls.X)
        cls.Y = cls.random_state.randint(0,2,10)
-        print(cls.Y)
        cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9])]
-        import pdb; pdb.set_trace()

    def test_simple(cls):
        X_train, y_train, X_test, y_test = ExecClassifMonoView.initTrainTest(cls.X, cls.Y, cls.classificationIndices)
-        np.testing.assert_array_equal(X_train, np.array([np.array([]), np.array([]), np.array([]), np.array([]), np.array([])]))
-        np.testing.assert_array_equal(X_test, np.array([np.array([]), np.array([]), np.array([]), np.array([]), np.array([])]))
-        np.testing.assert_array_equal(y_train, np.array([]))
-        np.testing.assert_array_equal(y_test, np.array([]))
+        np.testing.assert_array_equal(X_train, np.array([np.array([102,435,348,270,106]),
+                                                         np.array([466,214,330,458,87]),
+                                                         np.array([149,308,257,343,491]),
+                                                         np.array([276,160,459,313,21]),
+                                                         np.array([58,169,475,187,463])]))
+        np.testing.assert_array_equal(X_test, np.array([np.array([71,188,20,102,121]),
+                                                        np.array([372,99,359,151,130]),
+                                                        np.array([413,293,385,191,443]),
+                                                        np.array([252,235,344,48,474]),
+                                                        np.array([270,189,445,174,445])]))
+        np.testing.assert_array_equal(y_train, np.array([0,0,1,0,0]))
+        np.testing.assert_array_equal(y_test, np.array([1,1,0,0,0]))
\ No newline at end of file
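Note on the updated test above: the fixture is deterministic, so the expected arrays can be regenerated rather than typed by hand. RandomState(42) always produces the same X and Y, and initTrainTest keeps the even-indexed rows for training and the odd-indexed rows for testing. A short sketch, assuming extractSubset behaves like NumPy fancy indexing on the row axis:

import numpy as np

# Same deterministic fixture as in setUpClass.
random_state = np.random.RandomState(42)
X = random_state.randint(0, 500, (10, 5))
Y = random_state.randint(0, 2, 10)

trainIndices = np.array([0, 2, 4, 6, 8])
testIndices = np.array([1, 3, 5, 7, 9])

# Even-indexed rows/labels form the training split, odd-indexed ones the test split.
X_train, X_test = X[trainIndices], X[testIndices]
y_train, y_test = Y[trainIndices], Y[testIndices]

print(X_train)          # should reproduce the array hard-coded in the new X_train assertion
print(y_train, y_test)  # should reproduce the expected label splits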