Commit 4e11c380 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added fig

parent 27743786
Pipeline #5285 passed with stages
in 2 minutes and 46 seconds
...@@ -65,7 +65,6 @@ class MultiViewSubProblemsGenerator: ...@@ -65,7 +65,6 @@ class MultiViewSubProblemsGenerator:
if config_file is not None: if config_file is not None:
args = get_config_from_file(config_file) args = get_config_from_file(config_file)
self.__init__(**args) self.__init__(**args)
print(self.sub_problem_types)
else: else:
self.rs = init_random_state(random_state) self.rs = init_random_state(random_state)
self.n_samples = n_samples self.n_samples = n_samples
...@@ -264,8 +263,6 @@ class MultiViewSubProblemsGenerator: ...@@ -264,8 +263,6 @@ class MultiViewSubProblemsGenerator:
(self.n_examples_per_class[class_index] * (1 - confusion)).astype( (self.n_examples_per_class[class_index] * (1 - confusion)).astype(
int) int)
for class_index, confusion in enumerate(self.error_matrix)] for class_index, confusion in enumerate(self.error_matrix)]
print(self.n_well_described)
print(2000*0.6/3)
self.n_misdescribed = [(self.n_examples_per_class[class_index] - self.n_misdescribed = [(self.n_examples_per_class[class_index] -
self.n_well_described[class_index]) self.n_well_described[class_index])
for class_index in range(self.n_classes)] for class_index in range(self.n_classes)]
...@@ -391,20 +388,21 @@ class MultiViewSubProblemsGenerator: ...@@ -391,20 +388,21 @@ class MultiViewSubProblemsGenerator:
return self.n_samples * self.latent_size_mult return self.n_samples * self.latent_size_mult
def _gen_indices_for(self, name="redundancy", error_mat_fun=lambda x: 1 - x): def _gen_indices_for(self, name="redundancy", error_mat_fun=lambda x: 1 - x):
quantity = getattr(self, name) quantities = getattr(self, name)
indices = getattr(self, name + "_indices") indices = getattr(self, name + "_indices")
if (np.repeat(quantity, self.n_views, if (np.repeat(quantities, self.n_views,
axis=1) > error_mat_fun(self.error_matrix)).any(): axis=1) > error_mat_fun(self.error_matrix)).any():
raise ValueError( raise ValueError(
"{} ({}) must be at least equal to the lowest accuracy rate " "{} ({}) must be at least equal to the lowest accuracy rate "
"of all the confusion matrix ({}).".format(name, "of all the confusion matrix ({}).".format(name,
quantity, np.min(error_mat_fun(self.error_matrix), axis=1))) quantities, np.min(error_mat_fun(self.error_matrix), axis=1)))
else: else:
for class_index, redundancy in enumerate(quantity): for class_index, quantity in enumerate(quantities):
indices[class_index] = self.rs.choice( indices[class_index] = self.rs.choice(
self.available_init_indices[class_index], self.available_init_indices[class_index],
size=int( size=int(
self.n_examples_per_class[class_index] * redundancy)) self.n_examples_per_class[class_index] * quantity),
replace=False)
self._update_example_indices( self._update_example_indices(
indices[class_index], name, indices[class_index], name,
class_index) class_index)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment