Skip to content
Snippets Groups Projects
Commit 4e11c380 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added fig

parent 27743786
No related branches found
No related tags found
No related merge requests found
Pipeline #5285 passed
docs/source/_static/fig_rec.png

243 KiB

...@@ -65,7 +65,6 @@ class MultiViewSubProblemsGenerator: ...@@ -65,7 +65,6 @@ class MultiViewSubProblemsGenerator:
if config_file is not None: if config_file is not None:
args = get_config_from_file(config_file) args = get_config_from_file(config_file)
self.__init__(**args) self.__init__(**args)
print(self.sub_problem_types)
else: else:
self.rs = init_random_state(random_state) self.rs = init_random_state(random_state)
self.n_samples = n_samples self.n_samples = n_samples
...@@ -264,8 +263,6 @@ class MultiViewSubProblemsGenerator: ...@@ -264,8 +263,6 @@ class MultiViewSubProblemsGenerator:
(self.n_examples_per_class[class_index] * (1 - confusion)).astype( (self.n_examples_per_class[class_index] * (1 - confusion)).astype(
int) int)
for class_index, confusion in enumerate(self.error_matrix)] for class_index, confusion in enumerate(self.error_matrix)]
print(self.n_well_described)
print(2000*0.6/3)
self.n_misdescribed = [(self.n_examples_per_class[class_index] - self.n_misdescribed = [(self.n_examples_per_class[class_index] -
self.n_well_described[class_index]) self.n_well_described[class_index])
for class_index in range(self.n_classes)] for class_index in range(self.n_classes)]
...@@ -391,20 +388,21 @@ class MultiViewSubProblemsGenerator: ...@@ -391,20 +388,21 @@ class MultiViewSubProblemsGenerator:
return self.n_samples * self.latent_size_mult return self.n_samples * self.latent_size_mult
def _gen_indices_for(self, name="redundancy", error_mat_fun=lambda x: 1 - x): def _gen_indices_for(self, name="redundancy", error_mat_fun=lambda x: 1 - x):
quantity = getattr(self, name) quantities = getattr(self, name)
indices = getattr(self, name + "_indices") indices = getattr(self, name + "_indices")
if (np.repeat(quantity, self.n_views, if (np.repeat(quantities, self.n_views,
axis=1) > error_mat_fun(self.error_matrix)).any(): axis=1) > error_mat_fun(self.error_matrix)).any():
raise ValueError( raise ValueError(
"{} ({}) must be at least equal to the lowest accuracy rate " "{} ({}) must be at least equal to the lowest accuracy rate "
"of all the confusion matrix ({}).".format(name, "of all the confusion matrix ({}).".format(name,
quantity, np.min(error_mat_fun(self.error_matrix), axis=1))) quantities, np.min(error_mat_fun(self.error_matrix), axis=1)))
else: else:
for class_index, redundancy in enumerate(quantity): for class_index, quantity in enumerate(quantities):
indices[class_index] = self.rs.choice( indices[class_index] = self.rs.choice(
self.available_init_indices[class_index], self.available_init_indices[class_index],
size=int( size=int(
self.n_examples_per_class[class_index] * redundancy)) self.n_examples_per_class[class_index] * quantity),
replace=False)
self._update_example_indices( self._update_example_indices(
indices[class_index], name, indices[class_index], name,
class_index) class_index)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment