diff --git a/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/QarBoostUtils.py b/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/QarBoostUtils.py index 4e8fa6c18c076a049189418949ad2d3d7eaea373..983ed265499205e81bdae4f3071eec95d6cf12c7 100644 --- a/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/QarBoostUtils.py +++ b/multiview_platform/MonoMultiViewClassifiers/Monoview/Additions/QarBoostUtils.py @@ -64,7 +64,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): for k in range(min(n-1, self.n_max_iterations-1 if self.n_max_iterations is not None else np.inf)): # Print dynamically the step and the error of the current classifier - print("{}/{}, eps :{}".format(k+2, self.n_max_iterations, self.voter_perfs[-1]), end="\r") + print("Resp. bound : {}, {}/{}, eps :{}".format(self.respected_bound, k+2, self.n_max_iterations, self.voter_perfs[-1]), end="\r") sol, new_voter_index = self.choose_new_voter(y_kernel_matrix, formatted_y) @@ -125,12 +125,17 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): self.previous_margins.append( np.multiply(y, self.previous_vote)) - self.train_metrics.append( - self.plotted_metric.score(y, np.sign(self.previous_vote))) + train_metric = self.plotted_metric.score(y, np.sign(self.previous_vote)) if self.use_r: - self.bounds.append(self.bounds[-1] * math.sqrt(1 - voter_perf ** 2)) + bound = self.bounds[-1] * math.sqrt(1 - voter_perf ** 2) else: - self.bounds.append(np.prod(np.sqrt(1-4*np.square(0.5-np.array(self.voter_perfs))))) + bound = np.prod(np.sqrt(1-4*np.square(0.5-np.array(self.voter_perfs)))) + + if train_metric > bound: + self.respected_bound = False + + self.train_metrics.append(train_metric) + self.bounds.append(bound) def compute_voter_weight(self, voter_perf): """used to compute the voter's weight according to the specified method (edge or error) """ @@ -207,10 +212,19 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): self.previous_margins.append( np.multiply(y, self.previous_vote)) - self.train_metrics.append( - self.plotted_metric.score(y, np.sign(self.previous_vote))) - self.bounds.append(math.sqrt(1 - r ** 2)) + train_metric =self.plotted_metric.score(y, np.sign(self.previous_vote)) + if self.use_r: + bound = math.sqrt(1 - r ** 2) + else: + bound = np.prod(np.sqrt(1-4*np.square(0.5-np.array(epsilon)))) + + if train_metric > bound: + self.respected_bound = False + + self.train_metrics.append(train_metric) + + self.bounds.append(bound) def format_X_y(self, X, y): """Formats the data : X -the examples- and y -the labels- to be used properly by the algorithm """ @@ -246,6 +260,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): self.bounds = [] self.previous_votes = [] self.previous_margins = [] + self.respected_bound = True def _compute_epsilon(self,y): """Updating the error variable, the old fashioned way uses the whole majority vote to update the error""" @@ -301,9 +316,9 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): indices = [] causes = [] for hypothese_index, hypothese in enumerate(y_kernel_matrix.transpose()): - if (hypothese_index not in self.chosen_columns_ or self.twice_the_same)\ - and set(self.chosen_columns_)!={hypothese_index} \ - and self._is_not_too_wrong(hypothese, y): + if (hypothese_index not in self.chosen_columns_ or self.twice_the_same) \ + and set(self.chosen_columns_)!={hypothese_index} \ + and self._is_not_too_wrong(hypothese, y): w = self._solve_one_weight_min_c(hypothese, y) if w[0] != "break": c_borns.append(self._cbound(w[0])) @@ -433,7 +448,9 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): args_dict = dict((arg_name, str(self.__dict__[arg_name])) for arg_name in self.printed_args_name_list) interpretString += "\n \n With arguments : \n"+u'\u2022 '+ ("\n"+u'\u2022 ').join(['%s: \t%s' % (key, value) - for (key, value) in args_dict.items()]) + for (key, value) in args_dict.items()]) + if not self.respected_bound: + interpretString += "\n\n The bound was not respected" return interpretString