From 349c8ed16f79d756f9768125a880760456bcb421 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Tue, 28 Jan 2020 11:10:59 +0100
Subject: [PATCH] Added information per view control

---
 generator/update_baptiste.py | 32 ++++++++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/generator/update_baptiste.py b/generator/update_baptiste.py
index d26fa60..726e4ed 100644
--- a/generator/update_baptiste.py
+++ b/generator/update_baptiste.py
@@ -6,6 +6,7 @@ from math import ceil, floor
 import pandas as pd
 import h5py
 
+
 class MultiviewDatasetGenetator():
 
     def __init__(self, n_samples=100, n_views=2, n_classes=2,
@@ -19,6 +20,7 @@ class MultiviewDatasetGenetator():
                                 standard_deviation=2,
                                 weights=None,
                                 flip_y=0.0,
+                                n_informative_weights=None,
                                 random_state=42, config_path=None):
         if config_path is not None:
             with open(config_path) as config_file:
@@ -38,6 +40,7 @@ class MultiviewDatasetGenetator():
             self.standard_deviation = standard_deviation
             self.weights = weights
             self.flip_y = flip_y
+            self.n_informative_weights = n_informative_weights
             if isinstance(random_state, np.random.RandomState):
                 self.random_state = random_state
             elif isinstance(random_state, int):
@@ -97,20 +100,39 @@ class MultiviewDatasetGenetator():
                                    flip_y=self.flip_y,
                                    class_sep=self.n_clusters_per_class * self.class_sep_factor,
                                    random_state=self.random_state, shuffle=False)
+        self.informative_indices = np.arange(self.dim_Z)[:self.n_informative]
         I_q = np.arange(self.Z.shape[1])
         meta_I_v = []
         self.results = []
-        for view in range(n_views):
+        for view_index in range(n_views):
+            if self.n_informative_weights is not None and len(self.n_informative_weights)==n_views:
+                if self.n_informative*self.n_informative_weights[view_index] > d_v[view_index]:
+                    n_informative_view = int(self.n_informative*self.n_informative_weights[view_index])
+                    d_v[view_index] = n_informative_view
+                    I_v = self.random_state.choice(self.informative_indices,
+                                                   size=n_informative_view,
+                                                   replace=False)
+                else:
+                    n_informative_view = int(self.n_informative*self.n_informative_weights[view_index])
+                    print(n_informative_view, d_v)
+                    informative_indices = self.random_state.choice(self.informative_indices,
+                                                   size=n_informative_view,
+                                                   replace=False)
+                    I_v = np.concatenate((informative_indices,
+                                         self.random_state.choice(np.arange(self.dim_Z)[self.n_informative:],
+                                                                  size=d_v[view_index]-n_informative_view,
+                                                                  replace=False)))
+            else:
             # choice d_v[view] numeros of Z columns uniformly from I_q
-            I_v = self.random_state.choice(I_q, size=d_v[view],
-                                   replace=False)  # tirage dans I_q sans remise de taille d_v[view]
+                I_v = self.random_state.choice(I_q, size=d_v[view_index],
+                                               replace=False)  # tirage dans I_q sans remise de taille d_v[view]
             meta_I_v += list(I_v)
             # projection of Z along the columns in I_v
             X_v = self.projection( I_v)
             self.results.append((X_v, I_v))
             # remove R*d_v[view] columns numeros of I_v form I_q
             elements_to_remove = self.random_state.choice(I_v,
-                                                  size=floor(self.R * d_v[view]),
+                                                  size=floor(self.R * d_v[view_index]),
                                                   replace=False)  # tirage dans I_v sans remise de taille floor(R*d_v[view])
             I_q = np.setdiff1d(I_q,
                                elements_to_remove)  # I_q less elements from elements_to_remove
@@ -250,6 +272,7 @@ if __name__=="__main__":
     flip_y = 0.00  # Ratio of label noise
     random_state = 42
     weights = None # The proportions of examples in each class
+    n_informative_weights = np.array([0.1,0.5,0.9,0.4])
 
     path = "/home/baptiste/Documents/Datasets/Generated/metrics_dset/"
     name = "metrics"
@@ -269,6 +292,7 @@ if __name__=="__main__":
                                                     standard_deviation=standard_deviation,
                                                     flip_y=flip_y,
                                                     weights=weights,
+                                                    n_informative_weights=n_informative_weights,
                                                     random_state=random_state)
 
     multiview_generator.generate()
-- 
GitLab