Skip to content
Snippets Groups Projects
Commit e200099a authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added tests progressively

parent de914a4a
No related branches found
No related tags found
No related merge requests found
...@@ -344,15 +344,9 @@ def execClassif(arguments): ...@@ -344,15 +344,9 @@ def execClassif(arguments):
directory = execution.initLogFile(args) directory = execution.initLogFile(args)
randomState = execution.initRandomState(args.randomState, directory) randomState = execution.initRandomState(args.randomState, directory)
if statsIter > 1: statsIterRandomStates = execution.initStatsIterRandomStates(statsIter,randomState)
statsIterRandomStates = [np.random.RandomState(randomState.randint(500)) for _ in range(statsIter)]
else:
statsIterRandomStates = [randomState]
if args.name not in ["Fake", "Plausible"]: getDatabase = execution.getDatabaseFunction(args.name,args.type)
getDatabase = getattr(DB, "getClassicDB" + args.type[1:])
else:
getDatabase = getattr(DB, "get" + args.name + "DB" + args.type[1:])
DATASET, LABELS_DICTIONARY = getDatabase(args.views, args.pathF, args.name, args.CL_nbClass, DATASET, LABELS_DICTIONARY = getDatabase(args.views, args.pathF, args.name, args.CL_nbClass,
args.CL_classes, randomState, args.full) args.CL_classes, randomState, args.full)
......
...@@ -8,6 +8,8 @@ import logging ...@@ -8,6 +8,8 @@ import logging
import sklearn import sklearn
from . import GetMultiviewDb as DB
def parseTheArgs(arguments): def parseTheArgs(arguments):
"""Used to parse the args entered by the user""" """Used to parse the args entered by the user"""
...@@ -220,7 +222,14 @@ def parseTheArgs(arguments): ...@@ -220,7 +222,14 @@ def parseTheArgs(arguments):
def initRandomState(randomStateArg, directory): def initRandomState(randomStateArg, directory):
"""Used to init a random state and multiple if needed (multicore)""" """
Used to init a random state.
If no randomState is specified, it will take a 'random' seed.
If the arg is a string containing only numbers, it will be converted in an int to gen a seed.
If the arg is a string with letters, it must be a path to a pickled random state file that will be loaded.
The function will also pickle the new random state in a file tobe able to retrieve it later.
Tested
"""
if randomStateArg is None: if randomStateArg is None:
randomState = np.random.RandomState(randomStateArg) randomState = np.random.RandomState(randomStateArg)
else: else:
...@@ -236,25 +245,32 @@ def initRandomState(randomStateArg, directory): ...@@ -236,25 +245,32 @@ def initRandomState(randomStateArg, directory):
return randomState return randomState
def initStatsIterRandomStates(statsIter, randomState):
"""Used to init multiple random states if needed because of multiple statsIter"""
if statsIter > 1:
statsIterRandomStates = [np.random.RandomState(randomState.randint(5000)) for _ in range(statsIter)]
else:
statsIterRandomStates = [randomState]
return statsIterRandomStates
def getDatabaseFunction(name, type):
"""Used to get the right databes extraction function according to the type of and it's name"""
if name not in ["Fake", "Plausible"]:
getDatabase = getattr(DB, "getClassicDB" + type[1:])
else:
getDatabase = getattr(DB, "get" + name + "DB" + type[1:])
return getDatabase
def initLogFile(args): def initLogFile(args):
"""Used to init the directory where the preds will be stored and the log file""" """Used to init the directory where the preds will be stored and the log file"""
resultDirectory = "../Results/" + args.name + "/started_" + time.strftime("%Y_%m_%d-%H_%M") + "/" resultDirectory = "../Results/" + args.name + "/started_" + time.strftime("%Y_%m_%d-%H_%M") + "/"
logFileName = time.strftime("%Y_%m_%d-%H_%M_%S") + "-" + ''.join(args.CL_type) + "-" + "_".join( logFileName = time.strftime("%Y_%m_%d-%H_%M") + "-" + ''.join(args.CL_type) + "-" + "_".join(
args.views) + "-" + args.name + "-LOG" args.views) + "-" + args.name + "-LOG"
if not os.path.exists(os.path.dirname(resultDirectory + logFileName)): if os.path.exists(os.path.dirname(resultDirectory)):
try: raise NameError("The result dir already exists, wait 1 min and retry")
os.makedirs(os.path.dirname(resultDirectory + logFileName))
except OSError as exc:
if exc.errno != errno.EEXIST:
raise
logFile = resultDirectory + logFileName logFile = resultDirectory + logFileName
if os.path.isfile(logFile + ".log"):
for i in range(1, 20):
testFileName = logFileName + "-" + str(i) + ".log"
if not (os.path.isfile(resultDirectory + testFileName)):
logFile = resultDirectory + testFileName
break
else:
logFile += ".log" logFile += ".log"
logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', filename=logFile, level=logging.DEBUG, logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', filename=logFile, level=logging.DEBUG,
filemode='w') filemode='w')
......
...@@ -31,8 +31,6 @@ class Test_copyhdf5Dataset(unittest.TestCase): ...@@ -31,8 +31,6 @@ class Test_copyhdf5Dataset(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
os.remove("multiview_platform/Tests/temp_tests/test_copy.hdf5") os.remove("multiview_platform/Tests/temp_tests/test_copy.hdf5")
# for dir in os.listdir("multiview_platform/Tests/temp_tests"):print(dir)
# import pdb;pdb.set_trace()
os.rmdir("multiview_platform/Tests/temp_tests") os.rmdir("multiview_platform/Tests/temp_tests")
...@@ -126,8 +124,7 @@ class Test_filterLabels(unittest.TestCase): ...@@ -126,8 +124,7 @@ class Test_filterLabels(unittest.TestCase):
cls.availableLabelsNames, cls.availableLabelsNames,
cls.askedLabelsNames) cls.askedLabelsNames)
exception = catcher.exception exception = catcher.exception
# cls.assertTrue("Asked more labels than available in the dataset. Available labels are : test_label_0, test_label_1, test_label_2, test_label_3" in exception)
#
def test_asked_all_labels(cls): def test_asked_all_labels(cls):
cls.askedLabelsNamesSet = {"test_label_0", "test_label_1", "test_label_2", "test_label_3"} cls.askedLabelsNamesSet = {"test_label_0", "test_label_1", "test_label_2", "test_label_3"}
cls.askedLabelsNames = ["test_label_0", "test_label_1", "test_label_2", "test_label_3"] cls.askedLabelsNames = ["test_label_0", "test_label_1", "test_label_2", "test_label_3"]
......
...@@ -3,6 +3,8 @@ import argparse ...@@ -3,6 +3,8 @@ import argparse
import os import os
import h5py import h5py
import numpy as np import numpy as np
import shutil
import time
from sklearn.model_selection import StratifiedShuffleSplit from sklearn.model_selection import StratifiedShuffleSplit
...@@ -19,6 +21,52 @@ class Test_parseTheArgs(unittest.TestCase): ...@@ -19,6 +21,52 @@ class Test_parseTheArgs(unittest.TestCase):
# print args # print args
class Test_initStatsIterRandomStates(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.randomState = np.random.RandomState(42)
cls.statsIter = 1
def test_one_statiter(cls):
cls.state = cls.randomState.get_state()[1]
statsIterRandomStates = execution.initStatsIterRandomStates(cls.statsIter, cls.randomState)
np.testing.assert_array_equal(statsIterRandomStates[0].get_state()[1], cls.state)
def test_multiple_iter(cls):
cls.statsIter = 3
statsIterRandomStates = execution.initStatsIterRandomStates(cls.statsIter, cls.randomState)
cls.assertAlmostEqual(len(statsIterRandomStates), 3)
cls.assertNotEqual(statsIterRandomStates[0].randint(5000), statsIterRandomStates[1].randint(5000))
cls.assertNotEqual(statsIterRandomStates[0].randint(5000), statsIterRandomStates[2].randint(5000))
cls.assertNotEqual(statsIterRandomStates[2].randint(5000), statsIterRandomStates[1].randint(5000))
class Test_getDatabaseFunction(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.name = "zrtTap"
cls.type = ".csv"
def test_simple(cls):
getDB = execution.getDatabaseFunction(cls.name, cls.type)
from ...MonoMultiViewClassifiers.utils.GetMultiviewDb import getClassicDBcsv
cls.assertEqual(getDB, getClassicDBcsv)
def test_hdf5(cls):
cls.type = ".hdf5"
getDB = execution.getDatabaseFunction(cls.name, cls.type)
from ...MonoMultiViewClassifiers.utils.GetMultiviewDb import getClassicDBhdf5
cls.assertEqual(getDB, getClassicDBhdf5)
def test_plausible_hdf5(cls):
cls.name = "Plausible"
cls.type = ".hdf5"
getDB = execution.getDatabaseFunction(cls.name, cls.type)
from ...MonoMultiViewClassifiers.utils.GetMultiviewDb import getPlausibleDBhdf5
cls.assertEqual(getDB, getPlausibleDBhdf5)
class Test_initRandomState(unittest.TestCase): class Test_initRandomState(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -43,10 +91,34 @@ class Test_initRandomState(unittest.TestCase): ...@@ -43,10 +91,34 @@ class Test_initRandomState(unittest.TestCase):
pickled_randomState.beta(1,100,100)) pickled_randomState.beta(1,100,100))
class Test_initLogFile(unittest.TestCase): class FakeArg():
def test_initLogFile(self): def __init__(self):
pass self.name = "zrtTap"
self.CL_type = ["fromage","jambon"]
self.views = ["view1", "view2"]
self.log = True
# Impossible to test as the main directory is notthe same for the exec and the test
# class Test_initLogFile(unittest.TestCase):
#
# @classmethod
# def setUpClass(cls):
# cls.fakeArgs = FakeArg()
# cls.timestr = time.strftime("%Y_%m_%d-%H_%M")
#
# def test_initLogFile(cls):
# cls.timestr = time.strftime("%Y_%m_%d-%H_%M")
# execution.initLogFile(cls.fakeArgs)
# cls.assertIn("zrtTap", os.listdir("mutliview_platform/Results"), "Database directory not created")
# cls.assertIn("started_"+cls.timestr, os.listdir("mutliview_platform/Results/zrtTap"),"experimentation dir not created")
# cls.assertIn(cls.timestr + "-" + ''.join(cls.fakeArgs.CL_type) + "-" + "_".join(
# cls.fakeArgs.views) + "-" + cls.fakeArgs.name + "-LOG.log", os.listdir("mutliview_platform/Results/zrtTap/"+"started_"+cls.timestr), "logfile was not created")
#
# @classmethod
# def tearDownClass(cls):
# shutil.rmtree("multiview_platform/Results/zrtTap")
# pass
class Test_genSplits(unittest.TestCase): class Test_genSplits(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment