Skip to content
Snippets Groups Projects
Commit 7d5fbfc2 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Before tests

parent 300fe29b
No related branches found
No related tags found
No related merge requests found
......@@ -72,7 +72,6 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
"twice_the_same",
"c_bound_choice", "random_start",
"n_stumps", "use_r", "c_bound_sol"]
self.matrix_compute = False
def set_params(self, **params):
self.self_complemented = params["self_complemented"]
......@@ -242,7 +241,6 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
"""THis initialization corressponds to the first round of boosting with equal weights for each examples and the voter chosen by it's margin."""
self.example_weights = self._initialize_alphas(m).reshape((m, 1))
# self.previous_margins.append(np.multiply(y, y))
self.example_weights_.append(self.example_weights)
if self.random_start:
first_voter_index = self.random_state.choice(
......@@ -369,24 +367,6 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
pseudo_h_values[self.chosen_columns_] = ma.masked
return np.argmax(pseudo_h_values), [0]
def _is_not_too_wrong(self, hypothese, y):
"""Check if the weighted margin is better than random"""
if self.c_bound_sol:
return np.sum(hypothese) > 0
else:
print(np.average(hypothese.reshape(y.shape), weights=self.example_weights))
quit()
weighted_margin = np.average(hypothese.reshape(y.shape), weights=self.example_weights)#ondes matrix, axis=0
return weighted_margin > 0
def get_possible(self, y_kernel_matrix, y):
"""Get all the indices of the hypothesis that are good enough to be chosen"""
possibleIndices = []
for hypIndex, hypothese in enumerate(np.transpose(y_kernel_matrix)):
if self._is_not_too_wrong(hypothese, y):
possibleIndices.append(hypIndex)
return np.array(possibleIndices)
def _find_new_voter(self, y_kernel_matrix, y):
"""Here, we solve the two_voters_mincq_problem for each potential new voter,
and select the one that has the smallest minimum"""
......
......@@ -12,7 +12,7 @@ class CGreed(ColumnGenerationClassifierQar, BaseMonoviewClassifier):
twice_the_same=True,
c_bound_choice=True,
random_start=False,
n_stumps_per_attribute=1,
n_stumps_per_attribute=10,
use_r=True,
c_bound_sol=True
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment