Skip to content
Snippets Groups Projects
Commit 83d4f7cd authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

CQBoost max iter setup

parent 51a49e65
No related branches found
No related tags found
No related merge requests found
...@@ -153,15 +153,15 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -153,15 +153,15 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost):
def get_matrix_to_optimize(self, y_kernel_matrix, w=None): def get_matrix_to_optimize(self, y_kernel_matrix, w=None):
return y_kernel_matrix[:, self.chosen_columns_] return y_kernel_matrix[:, self.chosen_columns_]
def _binary_classification_matrix(self, X): # def _binary_classification_matrix(self, X):
probas = self._collect_probas(X) # probas = self._collect_probas(X)
predicted_labels = np.argmax(probas, axis=2) # predicted_labels = np.argmax(probas, axis=2)
predicted_labels[predicted_labels == 0] = -1 # predicted_labels[predicted_labels == 0] = -1
values = np.max(probas, axis=2) # values = np.max(probas, axis=2)
return (predicted_labels * values).T # return (predicted_labels * values).T
#
def _collect_probas(self, X): # def _collect_probas(self, X):
return np.asarray([clf.predict_proba(X) for clf in self.estimators_generator.estimators_]) # return np.asarray([clf.predict_proba(X) for clf in self.estimators_generator.estimators_])
def _restricted_master_problem(self, previous_w=None, previous_alpha=None): def _restricted_master_problem(self, previous_w=None, previous_alpha=None):
n_examples, n_hypotheses = self.matrix_to_optimize.shape n_examples, n_hypotheses = self.matrix_to_optimize.shape
......
...@@ -7,16 +7,17 @@ import os ...@@ -7,16 +7,17 @@ import os
class CQBoost(ColumnGenerationClassifier, BaseMonoviewClassifier): class CQBoost(ColumnGenerationClassifier, BaseMonoviewClassifier):
def __init__(self, random_state=None, mu=0.01, epsilon=1e-06, n_stumps=1, **kwargs): def __init__(self, random_state=None, mu=0.01, epsilon=1e-06, n_stumps=1, n_max_iterations=100, **kwargs):
super(CQBoost, self).__init__( super(CQBoost, self).__init__(
random_state=random_state, random_state=random_state,
mu=mu, mu=mu,
epsilon=epsilon, epsilon=epsilon,
estimators_generator="Stumps" estimators_generator="Stumps",
n_max_iterations=100
) )
self.param_names = ["mu", "epsilon", "n_stumps", "random_state"] self.param_names = ["mu", "epsilon", "n_stumps", "random_state", "n_max_iterations"]
self.distribs = [CustomUniform(loc=0.5, state=1.0, multiplier="e-"), self.distribs = [CustomUniform(loc=0.5, state=1.0, multiplier="e-"),
CustomRandint(low=1, high=15, multiplier="e-"), [n_stumps], [random_state]] CustomRandint(low=1, high=15, multiplier="e-"), [n_stumps], [random_state], [n_max_iterations]]
self.classed_params = [] self.classed_params = []
self.weird_strings = {} self.weird_strings = {}
self.n_stumps = n_stumps self.n_stumps = n_stumps
...@@ -51,7 +52,8 @@ def formatCmdArgs(args): ...@@ -51,7 +52,8 @@ def formatCmdArgs(args):
"""Used to format kwargs for the parsed args""" """Used to format kwargs for the parsed args"""
kwargsDict = {"mu": args.CQB_mu, kwargsDict = {"mu": args.CQB_mu,
"epsilon": args.CQB_epsilon, "epsilon": args.CQB_epsilon,
"n_stumps":args.CQB_stumps} "n_stumps":args.CQB_stumps,
"n_max_iterations":args.CQB_n_iter}
return kwargsDict return kwargsDict
......
...@@ -7,16 +7,18 @@ import os ...@@ -7,16 +7,18 @@ import os
class CQBoostTree(ColumnGenerationClassifier, BaseMonoviewClassifier): class CQBoostTree(ColumnGenerationClassifier, BaseMonoviewClassifier):
def __init__(self, random_state=None, mu=0.01, epsilon=1e-06, n_stumps=1, max_depth=2, **kwargs): def __init__(self, random_state=None, mu=0.01, epsilon=1e-06, n_stumps=1, max_depth=2, n_max_iterations=100, **kwargs):
print(n_max_iterations)
super(CQBoostTree, self).__init__( super(CQBoostTree, self).__init__(
random_state=random_state, random_state=random_state,
mu=mu, mu=mu,
epsilon=epsilon, epsilon=epsilon,
estimators_generator="Trees" estimators_generator="Trees",
n_max_iterations=n_max_iterations
) )
self.param_names = ["mu", "epsilon", "n_stumps", "random_state", "max_depth"] self.param_names = ["mu", "epsilon", "n_stumps", "random_state", "max_depth", "n_max_iterations"]
self.distribs = [CustomUniform(loc=0.5, state=1.0, multiplier="e-"), self.distribs = [CustomUniform(loc=0.5, state=1.0, multiplier="e-"),
CustomRandint(low=1, high=15, multiplier="e-"), [n_stumps], [random_state], [max_depth]] CustomRandint(low=1, high=15, multiplier="e-"), [n_stumps], [random_state], [max_depth], [n_max_iterations]]
self.classed_params = [] self.classed_params = []
self.weird_strings = {} self.weird_strings = {}
self.n_stumps = n_stumps self.n_stumps = n_stumps
...@@ -52,7 +54,8 @@ def formatCmdArgs(args): ...@@ -52,7 +54,8 @@ def formatCmdArgs(args):
kwargsDict = {"mu": args.CQBT_mu, kwargsDict = {"mu": args.CQBT_mu,
"epsilon": args.CQBT_epsilon, "epsilon": args.CQBT_epsilon,
"n_stumps":args.CQBT_trees, "n_stumps":args.CQBT_trees,
"max_depth":args.CQBT_max_depth} "max_depth":args.CQBT_max_depth,
"n_max_iterations":args.CQBT_n_iter}
return kwargsDict return kwargsDict
......
...@@ -53,7 +53,6 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier): ...@@ -53,7 +53,6 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier):
def predict(self, X): def predict(self, X):
pregen_X, _ = self.pregen_voters(X,) pregen_X, _ = self.pregen_voters(X,)
list_files = os.listdir(".") list_files = os.listdir(".")
print(list_files)
a = int(self.random_state.randint(0, 10000)) a = int(self.random_state.randint(0, 10000))
if "pregen_x"+str(a)+".csv" in list_files: if "pregen_x"+str(a)+".csv" in list_files:
a = int(np.random.randint(0, 10000)) a = int(np.random.randint(0, 10000))
......
...@@ -221,6 +221,10 @@ def parseTheArgs(arguments): ...@@ -221,6 +221,10 @@ def parseTheArgs(arguments):
action='store', action='store',
help='Set the number of stumps for CQBoost', help='Set the number of stumps for CQBoost',
default=1) default=1)
groupCQBoost.add_argument('--CQB_n_iter', metavar='INT', type=int,
action='store',
help='Set the maximum number of iteration in CQBoost',
default=None)
...@@ -305,6 +309,11 @@ def parseTheArgs(arguments): ...@@ -305,6 +309,11 @@ def parseTheArgs(arguments):
action='store', action='store',
help='Set the number of stumps for CQBoost', help='Set the number of stumps for CQBoost',
default=2) default=2)
groupCQBoostTree.add_argument('--CQBT_n_iter', metavar='INT', type=int,
action='store',
help='Set the maximum number of iteration in CQBoostTree',
default=None)
groupSCMPregenTree = parser.add_argument_group('SCMPregenTree arguments') groupSCMPregenTree = parser.add_argument_group('SCMPregenTree arguments')
groupSCMPregenTree.add_argument('--SCPT_max_rules', metavar='INT', type=int, groupSCMPregenTree.add_argument('--SCPT_max_rules', metavar='INT', type=int,
action='store', action='store',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment