Skip to content
Snippets Groups Projects
Commit 2e788d38 authored by bbauvin's avatar bbauvin
Browse files

Wrote some tests

parent de646719
Branches
Tags
No related merge requests found
......@@ -47,6 +47,7 @@ def initConstants(args, X, classificationIndices, labelsNames, name, directory):
raise
return kwargs, t_start, feat, CL_type, X, learningRate, labelsString, outputFileName
def initTrainTest(X, Y, classificationIndices):
trainIndices, testIndices = classificationIndices
X_train = extractSubset(X, trainIndices)
......@@ -55,19 +56,21 @@ def initTrainTest(X, Y, classificationIndices):
y_test = Y[testIndices]
return X_train, y_train, X_test, y_test
def getKWARGS(classifierModule, hyperParamSearch, nIter, CL_type, X_train, y_train, randomState,
outputFileName, KFolds, nbCores, metrics, kwargs):
if hyperParamSearch != "None":
logging.debug("Start:\t " + hyperParamSearch + " best settings with " + str(nIter) + " iterations for " + CL_type)
classifierHPSearch = getattr(classifierModule, hyperParamSearch)
logging.debug("Start:\t RandomSearch best settings with " + str(nIter) + " iterations for " + CL_type)
cl_desc = classifierHPSearch(X_train, y_train, randomState, outputFileName, KFolds=KFolds, nbCores=nbCores,
metric=metrics[0], nIter=nIter)
clKWARGS = dict((str(index), desc) for index, desc in enumerate(cl_desc))
logging.debug("Done:\t RandomSearch best settings")
logging.debug("Done:\t " + hyperParamSearch + "RandomSearch best settings")
else:
clKWARGS = kwargs[CL_type + "KWARGS"]
return clKWARGS
def saveResults(stringAnalysis, outputFileName, full_labels_pred, y_train_pred, y_train, imagesAnalysis):
logging.info(stringAnalysis)
outputTextFile = open(outputFileName + '.txt', 'w')
......
......@@ -53,21 +53,27 @@ class Test_initConstants(unittest.TestCase):
os.rmdir("Code/Tests/temp_tests/test_dir")
os.rmdir("Code/Tests/temp_tests")
class Test_initTrainTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.random_state = np.random.RandomState(42)
cls.X = cls.random_state.randint(0,500,(10,5))
print(cls.X)
cls.Y = cls.random_state.randint(0,2,10)
print(cls.Y)
cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9])]
import pdb; pdb.set_trace()
def test_simple(cls):
X_train, y_train, X_test, y_test = ExecClassifMonoView.initTrainTest(cls.X, cls.Y, cls.classificationIndices)
np.testing.assert_array_equal(X_train, np.array([np.array([]), np.array([]), np.array([]), np.array([]), np.array([])]))
np.testing.assert_array_equal(X_test, np.array([np.array([]), np.array([]), np.array([]), np.array([]), np.array([])]))
np.testing.assert_array_equal(y_train, np.array([]))
np.testing.assert_array_equal(y_test, np.array([]))
\ No newline at end of file
np.testing.assert_array_equal(X_train, np.array([np.array([102,435,348,270,106]),
np.array([466,214,330,458,87]),
np.array([149,308,257,343,491]),
np.array([276,160,459,313,21]),
np.array([58,169,475,187,463])]))
np.testing.assert_array_equal(X_test, np.array([np.array([71,188,20,102,121]),
np.array([372,99,359,151,130]),
np.array([413,293,385,191,443]),
np.array([252,235,344,48,474]),
np.array([270,189,445,174,445])]))
np.testing.assert_array_equal(y_train, np.array([0,0,1,0,0]))
np.testing.assert_array_equal(y_test, np.array([1,1,0,0,0]))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment