Skip to content
Snippets Groups Projects
Commit b15eb3c5 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Adjustments on training curve with noise

parent 05b4b9ab
No related branches found
No related tags found
No related merge requests found
......@@ -202,12 +202,12 @@ class StumpsClassifiersGenerator(ClassifiersGenerator):
self.estimators_ += [DecisionStumpClassifier(i,
(different[stump_number]+different[
stump_number+1])/2, 1).fit(X, y)
for stump_number in range(nb_different-1)]
for stump_number in range(int(nb_different)-1)]
if self.self_complemented:
self.estimators_ += [DecisionStumpClassifier(i,
(different[stump_number] + different[
stump_number + 1]) / 2, -1).fit(X, y)
for stump_number in range(nb_different-1)]
for stump_number in range(int(nb_different)-1)]
else:
self.estimators_ += [DecisionStumpClassifier(i, minimums[i] + ranges[i] * stump_number, 1).fit(X, y)
for stump_number in range(1, self.n_stumps_per_attribute + 1)
......
......@@ -22,7 +22,7 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost):
self.dual_constraint_rhs = dual_constraint_rhs
self.mu = mu
self.train_time = 0
self.plotted_metric = Metrics.accuracy_score
self.plotted_metric = Metrics.zero_one_loss
def fit(self, X, y):
start = time.time()
......
......@@ -502,6 +502,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
np.savetxt(directory+"y_train.csv", self.y_train, delimiter=',')
np.savetxt(directory + "raw_weights.csv", self.raw_weights, delimiter=',')
np.savetxt(directory + "c_bounds.csv", self.c_bounds, delimiter=',')
np.savetxt(directory + "train_metrics.csv", self.train_metrics, delimiter=',')
args_dict = dict(
(arg_name, str(self.__dict__[arg_name])) for arg_name in
self.printed_args_name_list)
......
......@@ -46,6 +46,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
interpretString += "\n\n Estimator error | Estimator weight\n"
interpretString += "\n".join([str(error) +" | "+ str(weight/sum(self.estimator_weights_)) for error, weight in zip(self.estimator_errors_, self.estimator_weights_)])
get_accuracy_graph(self.metrics, "Adaboost", directory+"metrics.png", self.plotted_metric_name, bounds=list(self.bounds))
np.savetxt(directory + "train_metrics.csv", self.metrics, delimiter=',')
return interpretString
......
......@@ -2,6 +2,7 @@ from ..Monoview.MonoviewUtils import CustomUniform, CustomRandint, BaseMonoviewC
from ..Monoview.Additions.CQBoostUtils import ColumnGenerationClassifier
from ..Monoview.Additions.BoostUtils import getInterpretBase
import numpy as np
class CQBoost(ColumnGenerationClassifier, BaseMonoviewClassifier):
......@@ -22,6 +23,7 @@ class CQBoost(ColumnGenerationClassifier, BaseMonoviewClassifier):
return True
def getInterpret(self, directory, y_test):
np.savetxt(directory + "train_metrics.csv", self.train_metrics, delimiter=',')
return getInterpretBase(self, directory, "CQBoost", self.weights_, y_test)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment