Select Git revision
-
Baptiste Bauvin authored
CQBoost.py 2.68 KiB
from ..Monoview.MonoviewUtils import CustomUniform, CustomRandint, BaseMonoviewClassifier
from ..Monoview.Additions.CQBoostUtils import ColumnGenerationClassifier
from ..Monoview.Additions.BoostUtils import getInterpretBase
import numpy as np
import os
class CQBoost(ColumnGenerationClassifier, BaseMonoviewClassifier):
    """CQBoost monoview classifier built on ColumnGenerationClassifier.

    The searchable hyper-parameters are ``mu`` and ``epsilon`` (see
    ``param_names`` and the matching ``distribs`` set in ``__init__``).
    """

    def __init__(self, random_state=None, mu=0.01, epsilon=1e-06, n_stumps=10,
                 **kwargs):
        """Initialise the classifier.

        :param random_state: forwarded to ColumnGenerationClassifier.
        :param mu: forwarded to ColumnGenerationClassifier.
        :param epsilon: forwarded to ColumnGenerationClassifier.
        :param n_stumps: stored as ``self.n_stumps``; it is NOT forwarded to
            the parent constructor here.
        :param kwargs: only ``nbCores`` is read (defaults to 1); any other
            key is ignored.
        """
        super(CQBoost, self).__init__(
            random_state=random_state,
            mu=mu,
            epsilon=epsilon
        )
        # Hyper-parameter names, paired index-by-index with ``distribs``.
        self.param_names = ["mu", "epsilon"]
        # NOTE(review): multiplier="e-" presumably means a sampled value v is
        # used as 10**-v (mirrors the module-level paramsToSet) - confirm in
        # MonoviewUtils.
        self.distribs = [CustomUniform(loc=0.5, state=1.0, multiplier="e-"),
                         CustomRandint(low=1, high=15, multiplier="e-")]
        self.classed_params = []
        self.weird_strings = {}
        self.n_stumps = n_stumps
        self.nbCores = kwargs.get("nbCores", 1)

    def fit(self, X, y):
        """Fit the underlying column-generation model.

        :param X: training samples, passed unchanged to the parent ``fit``.
        :param y: training labels, passed unchanged to the parent ``fit``.
        :returns: ``self``, following the scikit-learn estimator convention
            so that calls such as ``clf.fit(X, y).predict(X)`` can be chained.
        """
        # NOTE(review): an earlier version intended to restrict OpenMP
        # threads (OMP_NUM_THREADS) when nbCores == 1. That code was disabled
        # and is not implemented; nbCores currently has no effect on fitting.
        super(CQBoost, self).fit(X, y)
        return self

    def canProbas(self):
        """Used to know if the classifier can return label probabilities"""
        return True

    def getInterpret(self, directory, y_test):
        """Dump step-wise diagnostics to CSV files and build the interpretation.

        Output paths are built by plain string concatenation, so
        ``directory`` must already end with a path separator (or be an
        intended file-name prefix).

        :param directory: path prefix for the written CSV files.
        :param y_test: true labels used to score each step's decisions.
        :returns: the result of ``getInterpretBase`` for this classifier.
        """
        np.savetxt(directory + "train_metrics.csv", self.train_metrics,
                   delimiter=',')
        np.savetxt(directory + "c_bounds.csv", self.c_bounds,
                   delimiter=',')
        np.savetxt(directory + "y_test_step.csv", self.step_decisions,
                   delimiter=',')
        # Score every column of step_decisions except the last one.
        # NOTE(review): why the final column is skipped is not visible
        # here - confirm it is intentional.
        n_scored_steps = self.step_decisions.shape[1] - 1
        step_metrics = np.array([
            self.plotted_metric.score(y_test,
                                      self.step_decisions[:, step_index])
            for step_index in range(n_scored_steps)
        ])
        np.savetxt(directory + "step_test_metrics.csv", step_metrics,
                   delimiter=',')
        return getInterpretBase(self, directory, "CQBoost", self.weights_,
                                y_test)
def formatCmdArgs(args):
    """Translate parsed command-line arguments into classifier kwargs.

    Reads ``args.CQB_mu``, ``args.CQB_epsilon`` and ``args.CQB_stumps`` and
    maps them onto the ``mu``, ``epsilon`` and ``n_stumps`` keyword names.
    """
    return dict(
        mu=args.CQB_mu,
        epsilon=args.CQB_epsilon,
        n_stumps=args.CQB_stumps,
    )
def paramsToSet(nIter, randomState):
    """Used for weighted linear early fusion to generate random search sets.

    Each of the ``nIter`` returned dicts holds:
      - ``mu``: ``10 ** -u`` with ``u`` drawn by ``randomState.uniform(0.5, 1.5)``;
      - ``epsilon``: ``10.0 ** -k`` with ``k`` drawn by ``randomState.randint(1, 15)``.

    :param nIter: number of parameter sets to draw.
    :param randomState: object exposing ``uniform(low, high)`` and
        ``randint(low, high)``; presumably a ``numpy.random.RandomState``
        (confirm with callers).
    :returns: list of ``nIter`` dicts, drawn in order (uniform then randint
        for each set).
    """
    # A float base and a Python-int exponent are used for epsilon because
    # NumPy integer scalars (as returned by RandomState.randint) raise
    # ValueError when raised to a negative integer power.
    return [
        {"mu": 10 ** -randomState.uniform(0.5, 1.5),
         "epsilon": 10.0 ** -int(randomState.randint(1, 15))}
        for _ in range(nIter)
    ]