Skip to content
Snippets Groups Projects
Commit 7d5fbfc2 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Before tests

parent 300fe29b
Branches
Tags
No related merge requests found
......@@ -72,7 +72,6 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
"twice_the_same",
"c_bound_choice", "random_start",
"n_stumps", "use_r", "c_bound_sol"]
self.matrix_compute = False
def set_params(self, **params):
self.self_complemented = params["self_complemented"]
......@@ -242,7 +241,6 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
"""THis initialization corressponds to the first round of boosting with equal weights for each examples and the voter chosen by it's margin."""
self.example_weights = self._initialize_alphas(m).reshape((m, 1))
# self.previous_margins.append(np.multiply(y, y))
self.example_weights_.append(self.example_weights)
if self.random_start:
first_voter_index = self.random_state.choice(
......@@ -369,24 +367,6 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
pseudo_h_values[self.chosen_columns_] = ma.masked
return np.argmax(pseudo_h_values), [0]
def _is_not_too_wrong(self, hypothese, y):
"""Check if the weighted margin is better than random"""
if self.c_bound_sol:
return np.sum(hypothese) > 0
else:
print(np.average(hypothese.reshape(y.shape), weights=self.example_weights))
quit()
weighted_margin = np.average(hypothese.reshape(y.shape), weights=self.example_weights)#ondes matrix, axis=0
return weighted_margin > 0
def get_possible(self, y_kernel_matrix, y):
"""Get all the indices of the hypothesis that are good enough to be chosen"""
possibleIndices = []
for hypIndex, hypothese in enumerate(np.transpose(y_kernel_matrix)):
if self._is_not_too_wrong(hypothese, y):
possibleIndices.append(hypIndex)
return np.array(possibleIndices)
def _find_new_voter(self, y_kernel_matrix, y):
"""Here, we solve the two_voters_mincq_problem for each potential new voter,
and select the one that has the smallest minimum"""
......
......@@ -12,7 +12,7 @@ class CGreed(ColumnGenerationClassifierQar, BaseMonoviewClassifier):
twice_the_same=True,
c_bound_choice=True,
random_start=False,
n_stumps_per_attribute=1,
n_stumps_per_attribute=10,
use_r=True,
c_bound_sol=True
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment