Skip to content
Snippets Groups Projects
Commit d26c9f30 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added compute_voter_weight and update_info_containers

parent d3109743
Branches
Tags
No related merge requests found
...@@ -88,26 +88,14 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -88,26 +88,14 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
self.break_cause = " epsilon was too small." self.break_cause = " epsilon was too small."
break break
if self.use_r: self.compute_voter_weight(r, epsilon)
self.q = 0.5*math.log((1+r)/(1-r))
else:
self.q = math.log((1 - epsilon) / epsilon)
self.weights_.append(self.q)
# Update the distribution on the examples. # Update the distribution on the examples.
self._update_example_weights(formatted_y) self.update_example_weights(formatted_y)
self.example_weights_.append(self.example_weights)
# Update the "previous vote" to prepare for the next iteration # Update the "previous vote" to prepare for the next iteration
self.previous_vote = np.matmul(self.classification_matrix[:, self.chosen_columns_], self.update_info_containers(formatted_y, r, k)
np.array(self.weights_).reshape((k + 2, 1))).reshape((m, 1))
self.previous_votes.append(self.previous_vote)
self.previous_margins.append(np.multiply(formatted_y, self.previous_vote))
self.train_metrics.append(self.plotted_metric.score(formatted_y, np.sign(self.previous_vote)))
# self.bounds.append(np.prod(np.sqrt(1-4*np.square(0.5-np.array(self.epsilons)))))
self.bounds.append(self.bounds[-1]*math.sqrt(1-r**2))
self.nb_opposed_voters = self.check_opposed_voters() self.nb_opposed_voters = self.check_opposed_voters()
self.estimators_generator.estimators_ = self.estimators_generator.estimators_[self.chosen_columns_] self.estimators_generator.estimators_ = self.estimators_generator.estimators_[self.chosen_columns_]
...@@ -134,6 +122,27 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -134,6 +122,27 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
self.predict_time = end - start self.predict_time = end - start
return signs_array return signs_array
def update_info_containers(self, y, r, k):
self.example_weights_.append(self.example_weights)
self.previous_vote = np.matmul(
self.classification_matrix[:, self.chosen_columns_],
np.array(self.weights_).reshape((k + 2, 1))).reshape((self.n_total_examples, 1))
self.previous_votes.append(self.previous_vote)
self.previous_margins.append(
np.multiply(y, self.previous_vote))
self.train_metrics.append(
self.plotted_metric.score(y, np.sign(self.previous_vote)))
self.bounds.append(self.bounds[-1] * math.sqrt(1 - r ** 2))
# self.bounds.append(np.prod(np.sqrt(1-4*np.square(0.5-np.array(self.epsilons)))))
def compute_voter_weight(self, r, epsilon):
if self.use_r:
self.q = 0.5 * math.log((1 + r) / (1 - r))
else:
self.q = math.log((1 - epsilon) / epsilon)
self.weights_.append(self.q)
def compute_voter_perf(self, formatted_y): def compute_voter_perf(self, formatted_y):
epsilon = self._compute_epsilon(formatted_y) epsilon = self._compute_epsilon(formatted_y)
self.epsilons.append(epsilon) self.epsilons.append(epsilon)
...@@ -186,7 +195,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -186,7 +195,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
self.weights_.append(self.q) self.weights_.append(self.q)
# Update the distribution on the examples. # Update the distribution on the examples.
self._update_example_weights(y) self.update_example_weights(y)
self.example_weights_.append(self.example_weights) self.example_weights_.append(self.example_weights)
self.previous_margins.append( self.previous_margins.append(
...@@ -241,7 +250,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -241,7 +250,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
r = np.average(ones_matrix, weights=self.example_weights, axis=0) r = np.average(ones_matrix, weights=self.example_weights, axis=0)
return r return r
def _update_example_weights(self, y): def update_example_weights(self, y):
"""Old fashioned exaple weights update uses the whole majority vote, the other way uses only the last voter.""" """Old fashioned exaple weights update uses the whole majority vote, the other way uses only the last voter."""
new_weights = self.example_weights.reshape((self.n_total_examples, 1))*np.exp(-self.q*np.multiply(y,self.new_voter)) new_weights = self.example_weights.reshape((self.n_total_examples, 1))*np.exp(-self.q*np.multiply(y,self.new_voter))
self.example_weights = new_weights/np.sum(new_weights) self.example_weights = new_weights/np.sum(new_weights)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment