Skip to content
Snippets Groups Projects
Commit 349c8ed1 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added information per view control

parent 1c3a3cce
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ from math import ceil, floor ...@@ -6,6 +6,7 @@ from math import ceil, floor
import pandas as pd import pandas as pd
import h5py import h5py
class MultiviewDatasetGenetator(): class MultiviewDatasetGenetator():
def __init__(self, n_samples=100, n_views=2, n_classes=2, def __init__(self, n_samples=100, n_views=2, n_classes=2,
...@@ -19,6 +20,7 @@ class MultiviewDatasetGenetator(): ...@@ -19,6 +20,7 @@ class MultiviewDatasetGenetator():
standard_deviation=2, standard_deviation=2,
weights=None, weights=None,
flip_y=0.0, flip_y=0.0,
n_informative_weights=None,
random_state=42, config_path=None): random_state=42, config_path=None):
if config_path is not None: if config_path is not None:
with open(config_path) as config_file: with open(config_path) as config_file:
...@@ -38,6 +40,7 @@ class MultiviewDatasetGenetator(): ...@@ -38,6 +40,7 @@ class MultiviewDatasetGenetator():
self.standard_deviation = standard_deviation self.standard_deviation = standard_deviation
self.weights = weights self.weights = weights
self.flip_y = flip_y self.flip_y = flip_y
self.n_informative_weights = n_informative_weights
if isinstance(random_state, np.random.RandomState): if isinstance(random_state, np.random.RandomState):
self.random_state = random_state self.random_state = random_state
elif isinstance(random_state, int): elif isinstance(random_state, int):
...@@ -97,12 +100,31 @@ class MultiviewDatasetGenetator(): ...@@ -97,12 +100,31 @@ class MultiviewDatasetGenetator():
flip_y=self.flip_y, flip_y=self.flip_y,
class_sep=self.n_clusters_per_class * self.class_sep_factor, class_sep=self.n_clusters_per_class * self.class_sep_factor,
random_state=self.random_state, shuffle=False) random_state=self.random_state, shuffle=False)
self.informative_indices = np.arange(self.dim_Z)[:self.n_informative]
I_q = np.arange(self.Z.shape[1]) I_q = np.arange(self.Z.shape[1])
meta_I_v = [] meta_I_v = []
self.results = [] self.results = []
for view in range(n_views): for view_index in range(n_views):
if self.n_informative_weights is not None and len(self.n_informative_weights)==n_views:
if self.n_informative*self.n_informative_weights[view_index] > d_v[view_index]:
n_informative_view = int(self.n_informative*self.n_informative_weights[view_index])
d_v[view_index] = n_informative_view
I_v = self.random_state.choice(self.informative_indices,
size=n_informative_view,
replace=False)
else:
n_informative_view = int(self.n_informative*self.n_informative_weights[view_index])
print(n_informative_view, d_v)
informative_indices = self.random_state.choice(self.informative_indices,
size=n_informative_view,
replace=False)
I_v = np.concatenate((informative_indices,
self.random_state.choice(np.arange(self.dim_Z)[self.n_informative:],
size=d_v[view_index]-n_informative_view,
replace=False)))
else:
# choice d_v[view] numeros of Z columns uniformly from I_q # choice d_v[view] numeros of Z columns uniformly from I_q
I_v = self.random_state.choice(I_q, size=d_v[view], I_v = self.random_state.choice(I_q, size=d_v[view_index],
replace=False) # tirage dans I_q sans remise de taille d_v[view] replace=False) # tirage dans I_q sans remise de taille d_v[view]
meta_I_v += list(I_v) meta_I_v += list(I_v)
# projection of Z along the columns in I_v # projection of Z along the columns in I_v
...@@ -110,7 +132,7 @@ class MultiviewDatasetGenetator(): ...@@ -110,7 +132,7 @@ class MultiviewDatasetGenetator():
self.results.append((X_v, I_v)) self.results.append((X_v, I_v))
# remove R*d_v[view] columns numeros of I_v form I_q # remove R*d_v[view] columns numeros of I_v form I_q
elements_to_remove = self.random_state.choice(I_v, elements_to_remove = self.random_state.choice(I_v,
size=floor(self.R * d_v[view]), size=floor(self.R * d_v[view_index]),
replace=False) # tirage dans I_v sans remise de taille floor(R*d_v[view]) replace=False) # tirage dans I_v sans remise de taille floor(R*d_v[view])
I_q = np.setdiff1d(I_q, I_q = np.setdiff1d(I_q,
elements_to_remove) # I_q less elements from elements_to_remove elements_to_remove) # I_q less elements from elements_to_remove
...@@ -250,6 +272,7 @@ if __name__=="__main__": ...@@ -250,6 +272,7 @@ if __name__=="__main__":
flip_y = 0.00 # Ratio of label noise flip_y = 0.00 # Ratio of label noise
random_state = 42 random_state = 42
weights = None # The proportions of examples in each class weights = None # The proportions of examples in each class
n_informative_weights = np.array([0.1,0.5,0.9,0.4])
path = "/home/baptiste/Documents/Datasets/Generated/metrics_dset/" path = "/home/baptiste/Documents/Datasets/Generated/metrics_dset/"
name = "metrics" name = "metrics"
...@@ -269,6 +292,7 @@ if __name__=="__main__": ...@@ -269,6 +292,7 @@ if __name__=="__main__":
standard_deviation=standard_deviation, standard_deviation=standard_deviation,
flip_y=flip_y, flip_y=flip_y,
weights=weights, weights=weights,
n_informative_weights=n_informative_weights,
random_state=random_state) random_state=random_state)
multiview_generator.generate() multiview_generator.generate()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment