Skip to content
Snippets Groups Projects
Commit d0a8eb33 authored by Baptiste Bauvin
Browse files

Gradient boosting name

parent 40c7ee5a
Branches
Tags
No related merge requests found
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier from sklearn.ensemble import GradientBoostingClassifier
import time
import numpy as np
from ..Monoview.MonoviewUtils import CustomRandint, BaseMonoviewClassifier from ..Monoview.MonoviewUtils import CustomRandint, BaseMonoviewClassifier
from .. import Metrics
from ..Monoview.Additions.BoostUtils import get_accuracy_graph
# Author-Info # Author-Info
__author__ = "Baptiste Bauvin" __author__ = "Baptiste Bauvin"
__status__ = "Prototype" # Production, Development, Prototype __status__ = "Prototype" # Production, Development, Prototype
class CustomDecisionTree(DecisionTreeClassifier):
    """Decision tree classifier whose predictions are a float column vector."""

    def predict(self, X, check_input=True):
        """Predict labels for ``X`` and return them as an ``(n_samples, 1)`` float array.

        The prediction itself is delegated to ``DecisionTreeClassifier.predict``;
        this override only reshapes the result into a single column and casts
        it to ``float``.
        NOTE(review): presumably this is the layout gradient boosting expects
        from its ``init`` estimator -- confirm against the scikit-learn
        version in use.
        """
        labels = super(CustomDecisionTree, self).predict(
            X, check_input=check_input)
        n_rows = labels.shape[0]
        column = labels.reshape((n_rows, 1))
        return column.astype(float)
class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier): class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier):
def __init__(self, random_state=None, loss="exponential", max_depth=1.0, def __init__(self, random_state=None, loss="exponential", max_depth=1.0,
n_estimators=100, init=DecisionTreeClassifier(max_depth=1), **kwargs): n_estimators=100,
init=CustomDecisionTree(max_depth=1),
**kwargs):
super(GradientBoosting, self).__init__( super(GradientBoosting, self).__init__(
loss=loss, loss=loss,
max_depth=max_depth, max_depth=max_depth,
...@@ -23,6 +34,33 @@ class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier): ...@@ -23,6 +34,33 @@ class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier):
self.classed_params = [] self.classed_params = []
self.distribs = [CustomRandint(low=50, high=500),] self.distribs = [CustomRandint(low=50, high=500),]
self.weird_strings = {} self.weird_strings = {}
self.plotted_metric = Metrics.zero_one_loss
self.plotted_metric_name = "zero_one_loss"
self.step_predictions = None
def fit(self, X, y, sample_weight=None, monitor=None):
    """Fit the booster, timing the call and recording per-stage train metrics.

    Parameters
    ----------
    X : array-like with a ``shape`` attribute
        Training samples.
    y : array-like
        Training labels.
    sample_weight : array-like or None
        Forwarded unchanged to the parent ``fit``.
    monitor : callable or None
        Forwarded unchanged to the parent ``fit`` (it used to be accepted
        here but silently dropped).

    Returns
    -------
    self

    Side effects: sets ``train_time``, ``train_shape``, ``base_predictions``
    and ``metrics`` on the instance.
    """
    begin = time.time()
    super(GradientBoosting, self).fit(X, y, sample_weight=sample_weight,
                                      monitor=monitor)
    end = time.time()
    self.train_time = end - begin
    # Remembered so that predict() can tell the training set apart from
    # other inputs by comparing shapes.
    self.train_shape = X.shape
    # Prediction of the first estimator of every boosting stage on X.
    self.base_predictions = np.array(
        [estim[0].predict(X) for estim in self.estimators_])
    # Ground truth first, prediction second, matching the call made in
    # getInterpret().
    self.metrics = np.array(
        [self.plotted_metric.score(y, pred)
         for pred in self.staged_predict(X)])
    return self
def predict(self, X):
    """Predict labels for ``X`` and record how long the prediction took.

    The elapsed time is stored in ``self.pred_time``. When the shape of
    ``X`` differs from the shape seen by ``fit`` (``self.train_shape``),
    the staged predictions for ``X`` are additionally stored in
    ``self.step_predictions``; otherwise that attribute is left untouched.
    """
    started = time.time()
    labels = super(GradientBoosting, self).predict(X)
    self.pred_time = time.time() - started
    is_training_input = X.shape == self.train_shape
    if not is_training_input:
        self.step_predictions = np.array(list(self.staged_predict(X)))
    return labels
def canProbas(self): def canProbas(self):
"""Used to know if the classifier can return label probabilities""" """Used to know if the classifier can return label probabilities"""
...@@ -30,6 +68,14 @@ class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier): ...@@ -30,6 +68,14 @@ class GradientBoosting(GradientBoostingClassifier, BaseMonoviewClassifier):
def getInterpret(self, directory, y_test):
    """Write interpretation artefacts to ``directory`` and return a summary.

    Parameters
    ----------
    directory : str
        Path prefix that file names are appended to by plain string
        concatenation (so it must already end with a separator).
    y_test : array-like
        True labels matching the predictions stored in
        ``self.step_predictions``.

    Returns
    -------
    str
        Whatever ``self.getFeatureImportance(directory)`` returns.

    Files written: ``metrics.png``, ``train_metrics.csv``, ``times.csv``
    and, when staged test predictions are available, ``test_metrics.png``
    and ``test_metrics.csv``.
    """
    interpretString = ""
    interpretString += self.getFeatureImportance(directory)
    # predict() only fills step_predictions for inputs whose shape differs
    # from the training set; without them there is nothing to score, so the
    # test-side outputs are skipped instead of failing on iterating None.
    if self.step_predictions is not None:
        step_test_metrics = np.array(
            [self.plotted_metric.score(y_test, step_pred)
             for step_pred in self.step_predictions])
        # NOTE(review): the label used to read "AdaboostClassic", which looks
        # like a copy-paste from the Adaboost classifier -- confirm that the
        # second argument is only a display name.
        get_accuracy_graph(step_test_metrics, "GradientBoosting",
                           directory + "test_metrics.png",
                           self.plotted_metric_name, set="test")
        np.savetxt(directory + "test_metrics.csv", step_test_metrics,
                   delimiter=',')
    get_accuracy_graph(self.metrics, "GradientBoosting",
                       directory + "metrics.png", self.plotted_metric_name)
    np.savetxt(directory + "train_metrics.csv", self.metrics, delimiter=',')
    np.savetxt(directory + "times.csv",
               np.array([self.train_time, self.pred_time]), delimiter=',')
    return interpretString
......
...@@ -186,16 +186,22 @@ def parseTheArgs(arguments): ...@@ -186,16 +186,22 @@ def parseTheArgs(arguments):
groupCGreed.add_argument('--CGR_n_iter', metavar='INT', type=int, action='store', groupCGreed.add_argument('--CGR_n_iter', metavar='INT', type=int, action='store',
help='Set the n_max_iterations parameter for CGreed', default=100) help='Set the n_max_iterations parameter for CGreed', default=100)
groupCGDesc = parser.add_argument_group('CGDesc arguments') groupGradientBoosting = parser.add_argument_group('CGDesc arguments')
groupCGDesc.add_argument('--CGD_stumps', metavar='INT', type=int, groupGradientBoosting.add_argument('--CGD_stumps', metavar='INT', type=int,
action='store', action='store',
help='Set the n_stumps_per_attribute parameter for CGreed', help='Set the n_stumps_per_attribute parameter for CGreed',
default=1) default=1)
groupCGDesc.add_argument('--CGD_n_iter', metavar='INT', type=int, groupGradientBoosting.add_argument('--CGD_n_iter', metavar='INT', type=int,
action='store', action='store',
help='Set the n_max_iterations parameter for CGreed', help='Set the n_max_iterations parameter for CGreed',
default=100) default=100)
groupGradientBoosting = parser.add_argument_group('Gradient Boosting arguments')
groupGradientBoosting.add_argument('--GB_n_est', metavar='INT', type=int,
action='store',
help='Set the n_estimators_parameter for Gradient Boosting',
default=1)
groupQarBoostv3 = parser.add_argument_group('QarBoostv3 arguments') groupQarBoostv3 = parser.add_argument_group('QarBoostv3 arguments')
groupQarBoostv3.add_argument('--QarB3_mu', metavar='FLOAT', type=float, action='store', groupQarBoostv3.add_argument('--QarB3_mu', metavar='FLOAT', type=float, action='store',
help='Set the mu parameter for QarBoostv3', default=0.001) help='Set the mu parameter for QarBoostv3', default=0.001)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment