From b3dc2153a27119e60903440e3c917f4245cbf852 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Fri, 25 Mar 2022 07:42:35 -0400
Subject: [PATCH] Feat ids

---
 summit/multiview_platform/utils/dataset.py | 57 +++++++++++-----------
 1 file changed, 29 insertions(+), 28 deletions(-)

diff --git a/summit/multiview_platform/utils/dataset.py b/summit/multiview_platform/utils/dataset.py
index 2a33b34b..15175976 100644
--- a/summit/multiview_platform/utils/dataset.py
+++ b/summit/multiview_platform/utils/dataset.py
@@ -14,7 +14,7 @@ from .organization import secure_file_path
  of SuMMIT'''
 
 
-class Dataset():
+class Dataset:
     """
     This is the base class for all the type of multiview datasets of SuMMIT.
     """
@@ -165,11 +165,9 @@ class Dataset():
 
         return selected_label_names
 
-    def gen_feat_id(self):
-        self.feature_ids =  [["ID_" + str(i) for i in
+    def gen_feat_id(self, view_ind):
+        self.feature_ids[view_ind] = ["ID_" + str(i) for i in
                                  range(self.get_v(view_ind).shape[1])]
-                                for view_ind in self.view_dict.values()]
-
 
 
 class RAMDataset(Dataset):
@@ -193,13 +191,14 @@ class RAMDataset(Dataset):
         self.name = name
         self.nb_view = len(self.views)
         self.is_temp = False
-        if feature_ids is not None:
-            feature_ids = [[feature_id if not is_just_number(feature_id)
-                            else "ID_" + feature_id for feature_id in
-                            feat_ids] for feat_ids in feature_ids]
-            self.feature_ids = feature_ids
-        else:
-            self.gen_feat_id()
+        self.feature_ids = list(range(self.nb_view))  # placeholders, one per view
+        for view_ind in range(self.nb_view):
+            if feature_ids is None:
+                self.gen_feat_id(view_ind)
+                continue
+            self.feature_ids[view_ind] = [
+                fid if not is_just_number(fid) else "ID_" + fid
+                for fid in feature_ids[view_ind]]
 
     def get_view_name(self, view_idx):
         return self.view_names[view_idx]
@@ -377,14 +376,15 @@ class HDF5Dataset(Dataset):
             else:
                 self.sample_ids = ["ID_" + str(i)
                                    for i in range(labels.shape[0])]
-            if feature_ids is not None:
-                feature_ids = [[feature_id if not is_just_number(feature_id)
-                              else "ID_" + feature_id for feature_id in
-                              feat_ids] for feat_ids in feature_ids]
-                self.feature_ids = feature_ids
-            else:
-                self.gen_feat_id()
-
+            self.feature_ids = [None] * self.nb_view
+            for view_index in range(self.nb_view):
+                if feature_ids is not None:
+                    # Store per view; ids matching is_just_number() get "ID_".
+                    self.feature_ids[view_index] = [
+                        fid if not is_just_number(fid) else "ID_" + fid
+                        for fid in feature_ids[view_index]]
+                else:
+                    self.gen_feat_id(view_index)
 
     def get_v(self, view_index, sample_indices=None):
         """ Extract the view and returns a numpy.ndarray containing the description
@@ -443,6 +443,7 @@ class HDF5Dataset(Dataset):
 
         """
         self.nb_view = self.dataset["Metadata"].attrs["nbView"]
+        self.feature_ids = [_ for _ in range(self.nb_view)]
         self.view_dict = self.get_view_dict()
         self.view_names = [self.dataset["View{}".format(ind)].attrs['name'] for ind in range(self.nb_view)]
         if "sample_ids" in self.dataset["Metadata"].keys():
@@ -454,14 +455,14 @@ class HDF5Dataset(Dataset):
         else:
             self.sample_ids = ["ID_" + str(i) for i in
                                range(self.dataset["Labels"].shape[0])]
-        if "feature_ids" in self.dataset["Metadata"].keys():
-            self.feature_ids = [[feature_id.decode()
-                               if not is_just_number(feature_id.decode())
-                               else "ID_" + feature_id.decode()
-                               for feature_id in feature_ids] for feature_ids in
-                               self.dataset["Metadata"]["feature_ids"]]
-        else:
-           self.gen_feat_id()
+        for view_index in range(self.nb_view):
+            key = "feature_ids-View{}".format(view_index)
+            if key in self.dataset["Metadata"].keys():
+                stored_ids = [raw.decode() for raw in self.dataset["Metadata"][key]]
+                self.feature_ids[view_index] = [
+                    fid if not is_just_number(fid) else "ID_" + fid for fid in stored_ids]
+            else:
+                self.gen_feat_id(view_index)
 
     def get_nb_samples(self):
         """
-- 
GitLab