Skip to content
Snippets Groups Projects
Commit 336421c5 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Adjustments on training curve with noise

parent b15eb3c5
Branches
Tags
No related merge requests found
......@@ -31,7 +31,7 @@ def randomizedSearch(X_train, y_train, randomState, outputFileName, classifierMo
nIter = nb_possible_combinations
randomSearch = RandomizedSearchCV(estimator, n_iter=nIter, param_distributions=params_dict, refit=True,
n_jobs=nbCores, scoring=scorer, cv=KFolds, random_state=randomState)
print(estimator)
print(X_train)
detector = randomSearch.fit(X_train, y_train)
bestParams = estimator.genBestParams(detector)
......
......@@ -29,6 +29,7 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
self.weird_strings = {"base_estimator": "class_name"}
self.plotted_metric = Metrics.zero_one_loss
self.plotted_metric_name = "zero_one_loss"
self.step_predictions = None
def fit(self, X, y, sample_weight=None):
super(Adaboost, self).fit(X, y, sample_weight=sample_weight)
......@@ -40,12 +41,20 @@ class Adaboost(AdaBoostClassifier, BaseMonoviewClassifier):
"""Used to know if the classifier can return label probabilities"""
return True
def predict(self, X):
super(Adaboost, self).predict(X)
self.step_predictions = np.array([step_pred for step_pred in self.staged_predict(X)])
def getInterpret(self, directory, y_test):
interpretString = ""
interpretString += self.getFeatureImportance(directory)
interpretString += "\n\n Estimator error | Estimator weight\n"
interpretString += "\n".join([str(error) +" | "+ str(weight/sum(self.estimator_weights_)) for error, weight in zip(self.estimator_errors_, self.estimator_weights_)])
step_test_metrics = np.array([self.plotted_metric.score(y_test, step_pred) for step_pred in self.step_predictions])
get_accuracy_graph(step_test_metrics, "Adaboost", directory + "test_metrics.png",
self.plotted_metric_name, set="test")
get_accuracy_graph(self.metrics, "Adaboost", directory+"metrics.png", self.plotted_metric_name, bounds=list(self.bounds))
np.savetxt(directory + "test_metrics.csv", step_test_metrics, delimiter=',')
np.savetxt(directory + "train_metrics.csv", self.metrics, delimiter=',')
return interpretString
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment