Skip to content
Snippets Groups Projects
Commit 84042eec authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Pre-reformat

parent 6f8efba6
No related branches found
No related tags found
No related merge requests found
......@@ -2,12 +2,10 @@ import scipy
import logging
import numpy as np
import numpy.ma as ma
from collections import defaultdict
import math
from sklearn.utils.validation import check_is_fitted
from sklearn.base import BaseEstimator, ClassifierMixin
import time
import matplotlib.pyplot as plt
from .BoostUtils import StumpsClassifiersGenerator, sign, BaseBoost, \
getInterpretBase, get_accuracy_graph, TreeClassifiersGenerator
......@@ -22,7 +20,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
c_bound_choice=True, random_start=True,
n_stumps=1, use_r=True, c_bound_sol=True,
plotted_metric=Metrics.zero_one_loss, save_train_data=True,
test_graph=True, mincq_tracking=False):
test_graph=True, mincq_tracking=True):
super(ColumnGenerationClassifierQar, self).__init__()
r"""
......@@ -157,6 +155,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
return signs_array
def step_predict(self, classification_matrix):
"""Used to predict with each step of the greedy algorithm to analyze its performance increase"""
if classification_matrix.shape != self.train_shape:
self.step_decisions = np.zeros(classification_matrix.shape)
self.mincq_step_decisions = np.zeros(classification_matrix.shape)
......@@ -200,14 +199,14 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
self.train_metrics.append(train_metric)
self.bounds.append(bound)
if self.mincq_tracking:
if self.mincq_tracking: # Used to compute the optimal c-bound distribution on the chose set
from ...MonoviewClassifiers.MinCQ import MinCqLearner
mincq = MinCqLearner(10e-3, "stumps", n_stumps_per_attribute=1)
mincq = MinCqLearner(10e-3, "stumps", n_stumps_per_attribute=1, self_complemented=False)
training_set = self.classification_matrix[:, self.chosen_columns_]
mincq.fit(training_set, y)
mincq_pred = mincq.predict(training_set)
self.mincq_learners.append(mincq)
self.mincq_train_metrics.append(self.plotted_metric.score(y, mincq_pred))
self.mincq_train_metrics.append(self.plotted_metric.score(y, change_label_to_minus(mincq_pred)))
self.mincq_weights.append(mincq.majority_vote._weights)
self.mincq_c_bounds.append(mincq.majority_vote.cbound_value(training_set, y.reshape((y.shape[0],))))
......@@ -243,7 +242,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
(self.n_total_examples, 1))
def choose_new_voter(self, y_kernel_matrix, formatted_y):
"""Used to chhoose the voter according to the specified criterion (margin or C-Bound"""
"""Used to choose the voter according to the specified criterion (margin or C-Bound"""
if self.c_bound_choice:
sol, new_voter_index = self._find_new_voter(y_kernel_matrix,
formatted_y)
......@@ -505,7 +504,6 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
step_mincq_test_metrics.append(self.plotted_metric.score(y_test,
self.mincq_step_decisions[:,
step_index]))
# step_mincq_test_metrics = np.array(step_mincq_test_metrics)
np.savetxt(directory + "mincq_step_test_metrics.csv",
step_mincq_test_metrics,
delimiter=',')
......
......@@ -86,6 +86,8 @@ class MinCqLearner(BaseEstimator, ClassifierMixin):
if (np.unique(y)!= [-1,1]).any():
y_reworked = np.copy(y)
y_reworked[np.where(y_reworked==0)] = -1
else:
y_reworked = y
assert self.voters_type in ['stumps', 'kernel', 'manual'], "MinCqLearner: voters_type must be 'stumps', 'kernel' or 'manual'"
......
......@@ -390,7 +390,7 @@ def parseTheArgs(arguments):
type=int,
action='store',
help='Number of stumps inthe pregenerated dataset',
default=2)
default=3)
groupLasso = parser.add_argument_group('Lasso arguments')
groupLasso.add_argument('--LA_n_iter', metavar='INT', type=int,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment