Skip to content
Snippets Groups Projects
Commit 8912f819 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

CQ-Boost monocore modifs

parent f6a627ca
No related branches found
No related tags found
No related merge requests found
......@@ -453,7 +453,7 @@ def execClassif(arguments):
monoviewAlgos = args.CL_algos_monoview
multiviewAlgos = args.CL_algos_multiview
directory = execution.initLogFile(args.name, args.views, args.CL_type, args.log, args.debug)
directory = execution.initLogFile(args.name, args.views, args.CL_type, args.log, args.debug, args.label)
randomState = execution.initRandomState(args.randomState, directory)
statsIterRandomStates = execution.initStatsIterRandomStates(statsIter,randomState)
......
......@@ -3,6 +3,7 @@ from ..Monoview.Additions.CQBoostUtils import ColumnGenerationClassifier
from ..Monoview.Additions.BoostUtils import getInterpretBase
import numpy as np
import os
class CQBoost(ColumnGenerationClassifier, BaseMonoviewClassifier):
......@@ -17,6 +18,19 @@ class CQBoost(ColumnGenerationClassifier, BaseMonoviewClassifier):
CustomRandint(low=1, high=15, multiplier="e-")]
self.classed_params = []
self.weird_strings = {}
if "nbCores" not in kwargs:
self.nbCores = 1
else:
self.nbCores = kwargs["nbCores"]
def fit(self, X, y):
if self.nbCores == 1:
pass
super(CQBoost, self).fit(X,y)
if self.nbCores == 1:
# os.environ['OMP_NUM_THREADS'] = num_threads
pass
def canProbas(self):
"""Used to know if the classifier can return label probabilities"""
......
......@@ -21,6 +21,9 @@ def parseTheArgs(arguments):
groupStandard.add_argument('-log', action='store_true', help='Use option to activate logging to console')
groupStandard.add_argument('--name', metavar='STRING', action='store', help='Name of Database (default: %(default)s)',
default='Plausible')
groupStandard.add_argument('--label', metavar='STRING', action='store',
help='Labeling the results directory (default: %(default)s)',
default='')
groupStandard.add_argument('--type', metavar='STRING', action='store',
help='Type of database : .hdf5 or .csv (default: %(default)s)',
default='.hdf5')
......@@ -366,7 +369,7 @@ def getDatabaseFunction(name, type):
return getDatabase
def initLogFile(name, views, CL_type, log, debug):
def initLogFile(name, views, CL_type, log, debug, label):
r"""Used to init the directory where the preds will be stored and the log file.
First this function will check if the result directory already exists (only one per minute is allowed).
......@@ -390,9 +393,9 @@ def initLogFile(name, views, CL_type, log, debug):
Reference to the main results directory for the benchmark.
"""
if debug:
resultDirectory = "../Results/" + name + "/debug_started_" + time.strftime("%Y_%m_%d-%H_%M_%S") + "/"
resultDirectory = "../Results/" + name + "/debug_started_" + time.strftime("%Y_%m_%d-%H_%M_%S") + "_" + label + "/"
else:
resultDirectory = "../Results/" + name + "/started_" + time.strftime("%Y_%m_%d-%H_%M") + "/"
resultDirectory = "../Results/" + name + "/started_" + time.strftime("%Y_%m_%d-%H_%M") +"_" + label + "/"
logFileName = time.strftime("%Y_%m_%d-%H_%M") + "-" + ''.join(CL_type) + "-" + "_".join(
views) + "-" + name + "-LOG"
if os.path.exists(os.path.dirname(resultDirectory)):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment