diff --git a/multiview_platform/MonoMultiViewClassifiers/ExecClassif.py b/multiview_platform/MonoMultiViewClassifiers/ExecClassif.py index 63588616f00f5841c310c4e58eb5b09214e14151..6519874e25f81d98d4aaa9cc9534b5b8d28ef623 100644 --- a/multiview_platform/MonoMultiViewClassifiers/ExecClassif.py +++ b/multiview_platform/MonoMultiViewClassifiers/ExecClassif.py @@ -344,15 +344,9 @@ def execClassif(arguments): directory = execution.initLogFile(args) randomState = execution.initRandomState(args.randomState, directory) - if statsIter > 1: - statsIterRandomStates = [np.random.RandomState(randomState.randint(500)) for _ in range(statsIter)] - else: - statsIterRandomStates = [randomState] + statsIterRandomStates = execution.initStatsIterRandomStates(statsIter,randomState) - if args.name not in ["Fake", "Plausible"]: - getDatabase = getattr(DB, "getClassicDB" + args.type[1:]) - else: - getDatabase = getattr(DB, "get" + args.name + "DB" + args.type[1:]) + getDatabase = execution.getDatabaseFunction(args.name,args.type) DATASET, LABELS_DICTIONARY = getDatabase(args.views, args.pathF, args.name, args.CL_nbClass, args.CL_classes, randomState, args.full) diff --git a/multiview_platform/MonoMultiViewClassifiers/utils/execution.py b/multiview_platform/MonoMultiViewClassifiers/utils/execution.py index 097914255da877277a6daae647e9f51947d72b5f..ff26d7e1695c98eb2ce8a46e5cf213b5e1b475dd 100644 --- a/multiview_platform/MonoMultiViewClassifiers/utils/execution.py +++ b/multiview_platform/MonoMultiViewClassifiers/utils/execution.py @@ -8,6 +8,8 @@ import logging import sklearn +from . import GetMultiviewDb as DB + def parseTheArgs(arguments): """Used to parse the args entered by the user""" @@ -220,7 +222,14 @@ def parseTheArgs(arguments): def initRandomState(randomStateArg, directory): - """Used to init a random state and multiple if needed (multicore)""" + """ + Used to init a random state. + If no randomState is specified, it will take a 'random' seed. + If the arg is a string containing only numbers, it will be converted in an int to gen a seed. + If the arg is a string with letters, it must be a path to a pickled random state file that will be loaded. + The function will also pickle the new random state in a file tobe able to retrieve it later. + Tested + """ if randomStateArg is None: randomState = np.random.RandomState(randomStateArg) else: @@ -236,26 +245,33 @@ def initRandomState(randomStateArg, directory): return randomState +def initStatsIterRandomStates(statsIter, randomState): + """Used to init multiple random states if needed because of multiple statsIter""" + if statsIter > 1: + statsIterRandomStates = [np.random.RandomState(randomState.randint(5000)) for _ in range(statsIter)] + else: + statsIterRandomStates = [randomState] + return statsIterRandomStates + + +def getDatabaseFunction(name, type): + """Used to get the right databes extraction function according to the type of and it's name""" + if name not in ["Fake", "Plausible"]: + getDatabase = getattr(DB, "getClassicDB" + type[1:]) + else: + getDatabase = getattr(DB, "get" + name + "DB" + type[1:]) + return getDatabase + + def initLogFile(args): """Used to init the directory where the preds will be stored and the log file""" resultDirectory = "../Results/" + args.name + "/started_" + time.strftime("%Y_%m_%d-%H_%M") + "/" - logFileName = time.strftime("%Y_%m_%d-%H_%M_%S") + "-" + ''.join(args.CL_type) + "-" + "_".join( + logFileName = time.strftime("%Y_%m_%d-%H_%M") + "-" + ''.join(args.CL_type) + "-" + "_".join( args.views) + "-" + args.name + "-LOG" - if not os.path.exists(os.path.dirname(resultDirectory + logFileName)): - try: - os.makedirs(os.path.dirname(resultDirectory + logFileName)) - except OSError as exc: - if exc.errno != errno.EEXIST: - raise + if os.path.exists(os.path.dirname(resultDirectory)): + raise NameError("The result dir already exists, wait 1 min and retry") logFile = resultDirectory + logFileName - if os.path.isfile(logFile + ".log"): - for i in range(1, 20): - testFileName = logFileName + "-" + str(i) + ".log" - if not (os.path.isfile(resultDirectory + testFileName)): - logFile = resultDirectory + testFileName - break - else: - logFile += ".log" + logFile += ".log" logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', filename=logFile, level=logging.DEBUG, filemode='w') if args.log: diff --git a/multiview_platform/Tests/Test_utils/test_GetMultiviewDB.py b/multiview_platform/Tests/Test_utils/test_GetMultiviewDB.py index a8d7628cf112cf3ee213570a6d4687c1c135d777..0a2e379d067da8279362aabe6f3dce8dcb4c8ecb 100644 --- a/multiview_platform/Tests/Test_utils/test_GetMultiviewDB.py +++ b/multiview_platform/Tests/Test_utils/test_GetMultiviewDB.py @@ -31,8 +31,6 @@ class Test_copyhdf5Dataset(unittest.TestCase): @classmethod def tearDownClass(cls): os.remove("multiview_platform/Tests/temp_tests/test_copy.hdf5") - # for dir in os.listdir("multiview_platform/Tests/temp_tests"):print(dir) - # import pdb;pdb.set_trace() os.rmdir("multiview_platform/Tests/temp_tests") @@ -126,8 +124,7 @@ class Test_filterLabels(unittest.TestCase): cls.availableLabelsNames, cls.askedLabelsNames) exception = catcher.exception - # cls.assertTrue("Asked more labels than available in the dataset. Available labels are : test_label_0, test_label_1, test_label_2, test_label_3" in exception) - # + def test_asked_all_labels(cls): cls.askedLabelsNamesSet = {"test_label_0", "test_label_1", "test_label_2", "test_label_3"} cls.askedLabelsNames = ["test_label_0", "test_label_1", "test_label_2", "test_label_3"] diff --git a/multiview_platform/Tests/Test_utils/test_execution.py b/multiview_platform/Tests/Test_utils/test_execution.py index c44eee6f1355f75e7757e64053e4d14753413fc7..6e34c649c9b0a499987243cd1edebbfcebfe2ec8 100644 --- a/multiview_platform/Tests/Test_utils/test_execution.py +++ b/multiview_platform/Tests/Test_utils/test_execution.py @@ -3,6 +3,8 @@ import argparse import os import h5py import numpy as np +import shutil +import time from sklearn.model_selection import StratifiedShuffleSplit @@ -19,6 +21,52 @@ class Test_parseTheArgs(unittest.TestCase): # print args +class Test_initStatsIterRandomStates(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.randomState = np.random.RandomState(42) + cls.statsIter = 1 + + def test_one_statiter(cls): + cls.state = cls.randomState.get_state()[1] + statsIterRandomStates = execution.initStatsIterRandomStates(cls.statsIter, cls.randomState) + np.testing.assert_array_equal(statsIterRandomStates[0].get_state()[1], cls.state) + + def test_multiple_iter(cls): + cls.statsIter = 3 + statsIterRandomStates = execution.initStatsIterRandomStates(cls.statsIter, cls.randomState) + cls.assertAlmostEqual(len(statsIterRandomStates), 3) + cls.assertNotEqual(statsIterRandomStates[0].randint(5000), statsIterRandomStates[1].randint(5000)) + cls.assertNotEqual(statsIterRandomStates[0].randint(5000), statsIterRandomStates[2].randint(5000)) + cls.assertNotEqual(statsIterRandomStates[2].randint(5000), statsIterRandomStates[1].randint(5000)) + +class Test_getDatabaseFunction(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.name = "zrtTap" + cls.type = ".csv" + + def test_simple(cls): + getDB = execution.getDatabaseFunction(cls.name, cls.type) + from ...MonoMultiViewClassifiers.utils.GetMultiviewDb import getClassicDBcsv + cls.assertEqual(getDB, getClassicDBcsv) + + def test_hdf5(cls): + cls.type = ".hdf5" + getDB = execution.getDatabaseFunction(cls.name, cls.type) + from ...MonoMultiViewClassifiers.utils.GetMultiviewDb import getClassicDBhdf5 + cls.assertEqual(getDB, getClassicDBhdf5) + + def test_plausible_hdf5(cls): + cls.name = "Plausible" + cls.type = ".hdf5" + getDB = execution.getDatabaseFunction(cls.name, cls.type) + from ...MonoMultiViewClassifiers.utils.GetMultiviewDb import getPlausibleDBhdf5 + cls.assertEqual(getDB, getPlausibleDBhdf5) + + class Test_initRandomState(unittest.TestCase): def setUp(self): @@ -43,10 +91,34 @@ class Test_initRandomState(unittest.TestCase): pickled_randomState.beta(1,100,100)) -class Test_initLogFile(unittest.TestCase): - - def test_initLogFile(self): - pass +class FakeArg(): + + def __init__(self): + self.name = "zrtTap" + self.CL_type = ["fromage","jambon"] + self.views = ["view1", "view2"] + self.log = True + +# Impossible to test as the main directory is notthe same for the exec and the test +# class Test_initLogFile(unittest.TestCase): +# +# @classmethod +# def setUpClass(cls): +# cls.fakeArgs = FakeArg() +# cls.timestr = time.strftime("%Y_%m_%d-%H_%M") +# +# def test_initLogFile(cls): +# cls.timestr = time.strftime("%Y_%m_%d-%H_%M") +# execution.initLogFile(cls.fakeArgs) +# cls.assertIn("zrtTap", os.listdir("mutliview_platform/Results"), "Database directory not created") +# cls.assertIn("started_"+cls.timestr, os.listdir("mutliview_platform/Results/zrtTap"),"experimentation dir not created") +# cls.assertIn(cls.timestr + "-" + ''.join(cls.fakeArgs.CL_type) + "-" + "_".join( +# cls.fakeArgs.views) + "-" + cls.fakeArgs.name + "-LOG.log", os.listdir("mutliview_platform/Results/zrtTap/"+"started_"+cls.timestr), "logfile was not created") +# +# @classmethod +# def tearDownClass(cls): +# shutil.rmtree("multiview_platform/Results/zrtTap") +# pass class Test_genSplits(unittest.TestCase):