diff --git a/docs/source/_static/fig_rec.png b/docs/source/_static/fig_rec.png new file mode 100644 index 0000000000000000000000000000000000000000..8b1dde52ecc5797568e18ec7dd5efe9ddb01a54e Binary files /dev/null and b/docs/source/_static/fig_rec.png differ diff --git a/multiview_generator/multiple_sub_problems.py b/multiview_generator/multiple_sub_problems.py index 553372eef119ed49110b7a66a8b9e0c81d95f59d..e92c8cab5e448f0a3fcbd4c76041022c40702779 100644 --- a/multiview_generator/multiple_sub_problems.py +++ b/multiview_generator/multiple_sub_problems.py @@ -65,7 +65,6 @@ class MultiViewSubProblemsGenerator: if config_file is not None: args = get_config_from_file(config_file) self.__init__(**args) - print(self.sub_problem_types) else: self.rs = init_random_state(random_state) self.n_samples = n_samples @@ -264,8 +263,6 @@ class MultiViewSubProblemsGenerator: (self.n_examples_per_class[class_index] * (1 - confusion)).astype( int) for class_index, confusion in enumerate(self.error_matrix)] - print(self.n_well_described) - print(2000*0.6/3) self.n_misdescribed = [(self.n_examples_per_class[class_index] - self.n_well_described[class_index]) for class_index in range(self.n_classes)] @@ -391,20 +388,21 @@ class MultiViewSubProblemsGenerator: return self.n_samples * self.latent_size_mult def _gen_indices_for(self, name="redundancy", error_mat_fun=lambda x: 1 - x): - quantity = getattr(self, name) + quantities = getattr(self, name) indices = getattr(self, name + "_indices") - if (np.repeat(quantity, self.n_views, + if (np.repeat(quantities, self.n_views, axis=1) > error_mat_fun(self.error_matrix)).any(): raise ValueError( "{} ({}) must be at least equal to the lowest accuracy rate " "of all the confusion matrix ({}).".format(name, - quantity, np.min(error_mat_fun(self.error_matrix), axis=1))) + quantities, np.min(error_mat_fun(self.error_matrix), axis=1))) else: - for class_index, redundancy in enumerate(quantity): + for class_index, quantity in enumerate(quantities): indices[class_index] = self.rs.choice( self.available_init_indices[class_index], size=int( - self.n_examples_per_class[class_index] * redundancy)) + self.n_examples_per_class[class_index] * quantity), + replace=False) self._update_example_indices( indices[class_index], name, class_index)