Skip to content
Snippets Groups Projects
Commit de646719 authored by bbauvin's avatar bbauvin
Browse files

Wrote some tests

parent 689b8746
Branches
No related tags found
No related merge requests found
......@@ -52,3 +52,22 @@ class Test_initConstants(unittest.TestCase):
os.rmdir("Code/Tests/temp_tests/test_dir/test_clf")
os.rmdir("Code/Tests/temp_tests/test_dir")
os.rmdir("Code/Tests/temp_tests")
class Test_initTrainTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.random_state = np.random.RandomState(42)
cls.X = cls.random_state.randint(0,500,(10,5))
print(cls.X)
cls.Y = cls.random_state.randint(0,2,10)
print(cls.Y)
cls.classificationIndices = [np.array([0,2,4,6,8]),np.array([1,3,5,7,9])]
import pdb; pdb.set_trace()
def test_simple(cls):
X_train, y_train, X_test, y_test = ExecClassifMonoView.initTrainTest(cls.X, cls.Y, cls.classificationIndices)
np.testing.assert_array_equal(X_train, np.array([np.array([]), np.array([]), np.array([]), np.array([]), np.array([])]))
np.testing.assert_array_equal(X_test, np.array([np.array([]), np.array([]), np.array([]), np.array([]), np.array([])]))
np.testing.assert_array_equal(y_train, np.array([]))
np.testing.assert_array_equal(y_test, np.array([]))
\ No newline at end of file
......@@ -61,7 +61,6 @@ class Test_genSplits(unittest.TestCase):
self.labels[self.X_indices[11:30]] = 2 # To test multiclass
self.splitRatio = 0.2
def test_simple(self):
splits = execution.genSplits(self.labels, self.splitRatio, self.statsIterRandomStates)
self.assertEqual(len(splits), 3)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment