Skip to content
Snippets Groups Projects
Commit 84042eec authored by Baptiste Bauvin
Browse files

Pre-reformat

parent 6f8efba6
No related branches found
No related tags found
No related merge requests found
...@@ -2,12 +2,10 @@ import scipy ...@@ -2,12 +2,10 @@ import scipy
import logging import logging
import numpy as np import numpy as np
import numpy.ma as ma import numpy.ma as ma
from collections import defaultdict
import math import math
from sklearn.utils.validation import check_is_fitted from sklearn.utils.validation import check_is_fitted
from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import BaseEstimator, ClassifierMixin
import time import time
import matplotlib.pyplot as plt
from .BoostUtils import StumpsClassifiersGenerator, sign, BaseBoost, \ from .BoostUtils import StumpsClassifiersGenerator, sign, BaseBoost, \
getInterpretBase, get_accuracy_graph, TreeClassifiersGenerator getInterpretBase, get_accuracy_graph, TreeClassifiersGenerator
...@@ -22,7 +20,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -22,7 +20,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
c_bound_choice=True, random_start=True, c_bound_choice=True, random_start=True,
n_stumps=1, use_r=True, c_bound_sol=True, n_stumps=1, use_r=True, c_bound_sol=True,
plotted_metric=Metrics.zero_one_loss, save_train_data=True, plotted_metric=Metrics.zero_one_loss, save_train_data=True,
test_graph=True, mincq_tracking=False): test_graph=True, mincq_tracking=True):
super(ColumnGenerationClassifierQar, self).__init__() super(ColumnGenerationClassifierQar, self).__init__()
r""" r"""
...@@ -157,6 +155,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -157,6 +155,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
return signs_array return signs_array
def step_predict(self, classification_matrix): def step_predict(self, classification_matrix):
"""Used to predict with each step of the greedy algorithm to analyze its performance increase"""
if classification_matrix.shape != self.train_shape: if classification_matrix.shape != self.train_shape:
self.step_decisions = np.zeros(classification_matrix.shape) self.step_decisions = np.zeros(classification_matrix.shape)
self.mincq_step_decisions = np.zeros(classification_matrix.shape) self.mincq_step_decisions = np.zeros(classification_matrix.shape)
...@@ -200,14 +199,14 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -200,14 +199,14 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
self.train_metrics.append(train_metric) self.train_metrics.append(train_metric)
self.bounds.append(bound) self.bounds.append(bound)
if self.mincq_tracking: if self.mincq_tracking: # Used to compute the optimal c-bound distribution on the chose set
from ...MonoviewClassifiers.MinCQ import MinCqLearner from ...MonoviewClassifiers.MinCQ import MinCqLearner
mincq = MinCqLearner(10e-3, "stumps", n_stumps_per_attribute=1) mincq = MinCqLearner(10e-3, "stumps", n_stumps_per_attribute=1, self_complemented=False)
training_set = self.classification_matrix[:, self.chosen_columns_] training_set = self.classification_matrix[:, self.chosen_columns_]
mincq.fit(training_set, y) mincq.fit(training_set, y)
mincq_pred = mincq.predict(training_set) mincq_pred = mincq.predict(training_set)
self.mincq_learners.append(mincq) self.mincq_learners.append(mincq)
self.mincq_train_metrics.append(self.plotted_metric.score(y, mincq_pred)) self.mincq_train_metrics.append(self.plotted_metric.score(y, change_label_to_minus(mincq_pred)))
self.mincq_weights.append(mincq.majority_vote._weights) self.mincq_weights.append(mincq.majority_vote._weights)
self.mincq_c_bounds.append(mincq.majority_vote.cbound_value(training_set, y.reshape((y.shape[0],)))) self.mincq_c_bounds.append(mincq.majority_vote.cbound_value(training_set, y.reshape((y.shape[0],))))
...@@ -243,7 +242,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -243,7 +242,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
(self.n_total_examples, 1)) (self.n_total_examples, 1))
def choose_new_voter(self, y_kernel_matrix, formatted_y): def choose_new_voter(self, y_kernel_matrix, formatted_y):
"""Used to chhoose the voter according to the specified criterion (margin or C-Bound""" """Used to choose the voter according to the specified criterion (margin or C-Bound"""
if self.c_bound_choice: if self.c_bound_choice:
sol, new_voter_index = self._find_new_voter(y_kernel_matrix, sol, new_voter_index = self._find_new_voter(y_kernel_matrix,
formatted_y) formatted_y)
...@@ -505,7 +504,6 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -505,7 +504,6 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
step_mincq_test_metrics.append(self.plotted_metric.score(y_test, step_mincq_test_metrics.append(self.plotted_metric.score(y_test,
self.mincq_step_decisions[:, self.mincq_step_decisions[:,
step_index])) step_index]))
# step_mincq_test_metrics = np.array(step_mincq_test_metrics)
np.savetxt(directory + "mincq_step_test_metrics.csv", np.savetxt(directory + "mincq_step_test_metrics.csv",
step_mincq_test_metrics, step_mincq_test_metrics,
delimiter=',') delimiter=',')
......
...@@ -86,6 +86,8 @@ class MinCqLearner(BaseEstimator, ClassifierMixin): ...@@ -86,6 +86,8 @@ class MinCqLearner(BaseEstimator, ClassifierMixin):
if (np.unique(y)!= [-1,1]).any(): if (np.unique(y)!= [-1,1]).any():
y_reworked = np.copy(y) y_reworked = np.copy(y)
y_reworked[np.where(y_reworked==0)] = -1 y_reworked[np.where(y_reworked==0)] = -1
else:
y_reworked = y
assert self.voters_type in ['stumps', 'kernel', 'manual'], "MinCqLearner: voters_type must be 'stumps', 'kernel' or 'manual'" assert self.voters_type in ['stumps', 'kernel', 'manual'], "MinCqLearner: voters_type must be 'stumps', 'kernel' or 'manual'"
......
...@@ -390,7 +390,7 @@ def parseTheArgs(arguments): ...@@ -390,7 +390,7 @@ def parseTheArgs(arguments):
type=int, type=int,
action='store', action='store',
help='Number of stumps inthe pregenerated dataset', help='Number of stumps inthe pregenerated dataset',
default=2) default=3)
groupLasso = parser.add_argument_group('Lasso arguments') groupLasso = parser.add_argument_group('Lasso arguments')
groupLasso.add_argument('--LA_n_iter', metavar='INT', type=int, groupLasso.add_argument('--LA_n_iter', metavar='INT', type=int,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment