From b5b25f7bc6fd914f13740fab2b180f15f0613b79 Mon Sep 17 00:00:00 2001 From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr> Date: Fri, 7 Feb 2020 16:14:57 +0100 Subject: [PATCH] Hotfix class selection with example_ids --- config_files/config_test.yml | 8 ++++---- .../mono_multi_view_classifiers/utils/dataset.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/config_files/config_test.yml b/config_files/config_test.yml index e40bbc9e..0a541ce0 100644 --- a/config_files/config_test.yml +++ b/config_files/config_test.yml @@ -1,15 +1,15 @@ # The base configuration of the benchmark Base : log: True - name: ["plausible", "koukou"] + name: ["lives_13view", "koukou"] label: "_" type: ".hdf5" views: - pathf: "../data/" + pathf: "/home/baptiste/Documents/Datasets/Alexis/data/" nice: 0 random_state: 42 nb_cores: 1 - full: True + full: False debug: True add_noise: False noise_std: 0.0 @@ -21,7 +21,7 @@ Classification: split: 0.9 nb_folds: 2 nb_class: 2 - classes: + classes: ["EMF", ] type: ["multiview", "monoview"] algos_monoview: ["decision_tree", "adaboost", "random_forest" ] algos_multiview: ["weighted_linear_early_fusion",] diff --git a/multiview_platform/mono_multi_view_classifiers/utils/dataset.py b/multiview_platform/mono_multi_view_classifiers/utils/dataset.py index e7692142..60062bf2 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/dataset.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/dataset.py @@ -291,8 +291,8 @@ class Dataset(): new_dataset_file = h5py.File(dataset_file_path,"w") self.dataset.copy("Metadata", new_dataset_file) if "example_ids" in self.dataset["Metadata"].keys(): - ex_ids = new_dataset_file["Metadata"]["example_ids"] - ex_ids[...] = np.array(self.example_ids)[example_indices].astype(np.dtype("S10")) + del new_dataset_file["Metadata"]["example_ids"] + ex_ids = new_dataset_file["Metadata"].create_dataset("example_ids", data=np.array(self.example_ids)[example_indices].astype(np.dtype("S10"))) else: new_dataset_file["Metadata"].create_dataset("example_ids", (len(self.example_ids), ), -- GitLab