Skip to content
Snippets Groups Projects
Commit 83d4f7cd authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

CQBoost max iter setup

parent 51a49e65
No related branches found
No related tags found
No related merge requests found
......@@ -153,15 +153,15 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost):
def get_matrix_to_optimize(self, y_kernel_matrix, w=None):
return y_kernel_matrix[:, self.chosen_columns_]
def _binary_classification_matrix(self, X):
probas = self._collect_probas(X)
predicted_labels = np.argmax(probas, axis=2)
predicted_labels[predicted_labels == 0] = -1
values = np.max(probas, axis=2)
return (predicted_labels * values).T
def _collect_probas(self, X):
return np.asarray([clf.predict_proba(X) for clf in self.estimators_generator.estimators_])
# def _binary_classification_matrix(self, X):
# probas = self._collect_probas(X)
# predicted_labels = np.argmax(probas, axis=2)
# predicted_labels[predicted_labels == 0] = -1
# values = np.max(probas, axis=2)
# return (predicted_labels * values).T
#
# def _collect_probas(self, X):
# return np.asarray([clf.predict_proba(X) for clf in self.estimators_generator.estimators_])
def _restricted_master_problem(self, previous_w=None, previous_alpha=None):
n_examples, n_hypotheses = self.matrix_to_optimize.shape
......
......@@ -7,16 +7,17 @@ import os
class CQBoost(ColumnGenerationClassifier, BaseMonoviewClassifier):
def __init__(self, random_state=None, mu=0.01, epsilon=1e-06, n_stumps=1, **kwargs):
def __init__(self, random_state=None, mu=0.01, epsilon=1e-06, n_stumps=1, n_max_iterations=100, **kwargs):
super(CQBoost, self).__init__(
random_state=random_state,
mu=mu,
epsilon=epsilon,
estimators_generator="Stumps"
estimators_generator="Stumps",
n_max_iterations=100
)
self.param_names = ["mu", "epsilon", "n_stumps", "random_state"]
self.param_names = ["mu", "epsilon", "n_stumps", "random_state", "n_max_iterations"]
self.distribs = [CustomUniform(loc=0.5, state=1.0, multiplier="e-"),
CustomRandint(low=1, high=15, multiplier="e-"), [n_stumps], [random_state]]
CustomRandint(low=1, high=15, multiplier="e-"), [n_stumps], [random_state], [n_max_iterations]]
self.classed_params = []
self.weird_strings = {}
self.n_stumps = n_stumps
......@@ -51,7 +52,8 @@ def formatCmdArgs(args):
"""Used to format kwargs for the parsed args"""
kwargsDict = {"mu": args.CQB_mu,
"epsilon": args.CQB_epsilon,
"n_stumps":args.CQB_stumps}
"n_stumps":args.CQB_stumps,
"n_max_iterations":args.CQB_n_iter}
return kwargsDict
......
......@@ -7,16 +7,18 @@ import os
class CQBoostTree(ColumnGenerationClassifier, BaseMonoviewClassifier):
def __init__(self, random_state=None, mu=0.01, epsilon=1e-06, n_stumps=1, max_depth=2, **kwargs):
def __init__(self, random_state=None, mu=0.01, epsilon=1e-06, n_stumps=1, max_depth=2, n_max_iterations=100, **kwargs):
print(n_max_iterations)
super(CQBoostTree, self).__init__(
random_state=random_state,
mu=mu,
epsilon=epsilon,
estimators_generator="Trees"
estimators_generator="Trees",
n_max_iterations=n_max_iterations
)
self.param_names = ["mu", "epsilon", "n_stumps", "random_state", "max_depth"]
self.param_names = ["mu", "epsilon", "n_stumps", "random_state", "max_depth", "n_max_iterations"]
self.distribs = [CustomUniform(loc=0.5, state=1.0, multiplier="e-"),
CustomRandint(low=1, high=15, multiplier="e-"), [n_stumps], [random_state], [max_depth]]
CustomRandint(low=1, high=15, multiplier="e-"), [n_stumps], [random_state], [max_depth], [n_max_iterations]]
self.classed_params = []
self.weird_strings = {}
self.n_stumps = n_stumps
......@@ -52,7 +54,8 @@ def formatCmdArgs(args):
kwargsDict = {"mu": args.CQBT_mu,
"epsilon": args.CQBT_epsilon,
"n_stumps":args.CQBT_trees,
"max_depth":args.CQBT_max_depth}
"max_depth":args.CQBT_max_depth,
"n_max_iterations":args.CQBT_n_iter}
return kwargsDict
......
......@@ -53,7 +53,6 @@ class SCMPregenTree(scm, BaseMonoviewClassifier, PregenClassifier):
def predict(self, X):
pregen_X, _ = self.pregen_voters(X,)
list_files = os.listdir(".")
print(list_files)
a = int(self.random_state.randint(0, 10000))
if "pregen_x"+str(a)+".csv" in list_files:
a = int(np.random.randint(0, 10000))
......
......@@ -221,6 +221,10 @@ def parseTheArgs(arguments):
action='store',
help='Set the number of stumps for CQBoost',
default=1)
groupCQBoost.add_argument('--CQB_n_iter', metavar='INT', type=int,
action='store',
help='Set the maximum number of iteration in CQBoost',
default=None)
......@@ -305,6 +309,11 @@ def parseTheArgs(arguments):
action='store',
help='Set the number of stumps for CQBoost',
default=2)
groupCQBoostTree.add_argument('--CQBT_n_iter', metavar='INT', type=int,
action='store',
help='Set the maximum number of iteration in CQBoostTree',
default=None)
groupSCMPregenTree = parser.add_argument_group('SCMPregenTree arguments')
groupSCMPregenTree.add_argument('--SCPT_max_rules', metavar='INT', type=int,
action='store',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment