Skip to content
Snippets Groups Projects
Commit 2e788d38 authored by bbauvin's avatar bbauvin
Browse files

Wrote some tests

parent de646719
No related branches found
No related tags found
No related merge requests found
......@@ -47,6 +47,7 @@ def initConstants(args, X, classificationIndices, labelsNames, name, directory):
raise
return kwargs, t_start, feat, CL_type, X, learningRate, labelsString, outputFileName
def initTrainTest(X, Y, classificationIndices):
trainIndices, testIndices = classificationIndices
X_train = extractSubset(X, trainIndices)
......@@ -55,19 +56,21 @@ def initTrainTest(X, Y, classificationIndices):
y_test = Y[testIndices]
return X_train, y_train, X_test, y_test
def getKWARGS(classifierModule, hyperParamSearch, nIter, CL_type, X_train, y_train, randomState,
outputFileName, KFolds, nbCores, metrics, kwargs):
if hyperParamSearch != "None":
logging.debug("Start:\t " + hyperParamSearch + " best settings with " + str(nIter) + " iterations for " + CL_type)
classifierHPSearch = getattr(classifierModule, hyperParamSearch)
logging.debug("Start:\t RandomSearch best settings with " + str(nIter) + " iterations for " + CL_type)
cl_desc = classifierHPSearch(X_train, y_train, randomState, outputFileName, KFolds=KFolds, nbCores=nbCores,
metric=metrics[0], nIter=nIter)
clKWARGS = dict((str(index), desc) for index, desc in enumerate(cl_desc))
logging.debug("Done:\t RandomSearch best settings")
logging.debug("Done:\t " + hyperParamSearch + "RandomSearch best settings")
else:
clKWARGS = kwargs[CL_type + "KWARGS"]
return clKWARGS
def saveResults(stringAnalysis, outputFileName, full_labels_pred, y_train_pred, y_train, imagesAnalysis):
logging.info(stringAnalysis)
outputTextFile = open(outputFileName + '.txt', 'w')
......
......@@ -53,21 +53,27 @@ class Test_initConstants(unittest.TestCase):
os.rmdir("Code/Tests/temp_tests/test_dir")
os.rmdir("Code/Tests/temp_tests")
class Test_initTrainTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.random_state = np.random.RandomState(42)
cls.X = cls.random_state.randint(0,500,(10,5))
print(cls.X)
cls.Y = cls.random_state.randint(0,2,10)
print(cls.Y)
cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9])]
import pdb; pdb.set_trace()
def test_simple(cls):
X_train, y_train, X_test, y_test = ExecClassifMonoView.initTrainTest(cls.X, cls.Y, cls.classificationIndices)
np.testing.assert_array_equal(X_train, np.array([np.array([]), np.array([]), np.array([]), np.array([]), np.array([])]))
np.testing.assert_array_equal(X_test, np.array([np.array([]), np.array([]), np.array([]), np.array([]), np.array([])]))
np.testing.assert_array_equal(y_train, np.array([]))
np.testing.assert_array_equal(y_test, np.array([]))
\ No newline at end of file
np.testing.assert_array_equal(X_train, np.array([np.array([102,435,348,270,106]),
np.array([466,214,330,458,87]),
np.array([149,308,257,343,491]),
np.array([276,160,459,313,21]),
np.array([58,169,475,187,463])]))
np.testing.assert_array_equal(X_test, np.array([np.array([71,188,20,102,121]),
np.array([372,99,359,151,130]),
np.array([413,293,385,191,443]),
np.array([252,235,344,48,474]),
np.array([270,189,445,174,445])]))
np.testing.assert_array_equal(y_train, np.array([0,0,1,0,0]))
np.testing.assert_array_equal(y_test, np.array([1,1,0,0,0]))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment