Skip to content
Snippets Groups Projects
Commit 7e8222c6 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added multiple kwargs per clf, to be tested

parent 19946122
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,7 @@ import pkgutil
import time
import matplotlib
import itertools
import numpy as np
from joblib import Parallel, delayed
......@@ -157,7 +158,7 @@ def initMonoviewExps(benchmark, viewsDictionary, nbClass, kwargsInit):
for viewName, viewIndex in viewsDictionary.items():
for classifier in benchmark["Monoview"]:
if multiple_args(classifier, kwargsInit):
argumentDictionaries["Monoview"] += gen_multiple_args_dictionnaries(nbClass, kwargsInit)
argumentDictionaries["Monoview"] += gen_multiple_args_dictionnaries(nbClass, kwargsInit, classifier, viewName, viewIndex)
else:
arguments = {
"args": {classifier + "KWARGS": kwargsInit[
......@@ -168,14 +169,37 @@ def initMonoviewExps(benchmark, viewsDictionary, nbClass, kwargsInit):
return argumentDictionaries
def multiple_args(classifier, kwargsInit):
listed_args = [type(value) == list and len(value)>1 for key, value in kwargsInit[classifier + "KWARGSInit"].items()]
listed_args = [type(value) == list and len(value)>1 for key, value in
kwargsInit[classifier + "KWARGSInit"].items()]
if True in listed_args:
return True
else:
return False
def gen_multiple_args_dictionnaries(nbClass, kwargsInit):
def gen_multiple_kwargs_combinations(clKWARGS):
values = list(clKWARGS.values())
listed_values = [[_] if type(_) is not list else _ for _ in values]
values_cartesian_prod = [_ for _ in itertools.product(*listed_values)]
keys = clKWARGS.keys()
kwargs_combination = [dict((key, value) for key, value in zip(keys, values))
for values in values_cartesian_prod]
return kwargs_combination
def gen_multiple_args_dictionnaries(nbClass, kwargsInit,
classifier, viewName, viewIndex):
multiple_kwargs_list = gen_multiple_kwargs_combinations(kwargsInit[classifier + "KWARGSInit"])
multiple_kwargs_dict = dict(
(classifier+"_"+"_".join(map(str,list(dictionary.values()))), dictionary)
for dictionary in multiple_kwargs_list)
args_dictionnaries = [{
"args": {classifier_name + "KWARGS": arguments,
"feat": viewName,
"CL_type": classifier_name,
"nbClass": nbClass},
"viewIndex": viewIndex}
for classifier_name, arguments in multiple_kwargs_dict.items()]
return args_dictionnaries
def initMonoviewKWARGS(args, classifiersNames):
......
......@@ -77,7 +77,8 @@ def ExecMonoview(directory, X, Y, name, labelsNames, classificationIndices,
logging.debug("Done:\t Determine Train/Test split")
logging.debug("Start:\t Generate classifier args")
classifierModule = getattr(MonoviewClassifiers, CL_type)
classifierModuleName = CL_type.split("_")[0]
classifierModule = getattr(MonoviewClassifiers, classifierModuleName)
clKWARGS, testFoldsPreds = getHPs(classifierModule, hyperParamSearch,
nIter, CL_type, X_train, y_train,
randomState, outputFileName,
......@@ -85,7 +86,7 @@ def ExecMonoview(directory, X, Y, name, labelsNames, classificationIndices,
logging.debug("Done:\t Generate classifier args")
logging.debug("Start:\t Training")
classifier = getattr(classifierModule, CL_type)(randomState, **clKWARGS)
classifier = getattr(classifierModule, classifierModuleName)(randomState, **clKWARGS)
classifier.fit(X_train, y_train) # NB_CORES=nbCores,
logging.debug("Done:\t Training")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment