From 9dd0172b23ddd2314a607ef898a43332d363a8b2 Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Wed, 27 Nov 2019 08:51:21 -0500
Subject: [PATCH] Removed prints and now saving config file in result directory

---
 .../mono_multi_view_classifiers/exec_classif.py          | 2 +-
 .../weighted_linear_early_fusion.py                      | 3 ---
 .../mono_multi_view_classifiers/utils/configuration.py   | 9 +++++++++
 .../mono_multi_view_classifiers/utils/execution.py       | 5 +++--
 4 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/multiview_platform/mono_multi_view_classifiers/exec_classif.py b/multiview_platform/mono_multi_view_classifiers/exec_classif.py
index 88ae2c87..5a91fb0d 100644
--- a/multiview_platform/mono_multi_view_classifiers/exec_classif.py
+++ b/multiview_platform/mono_multi_view_classifiers/exec_classif.py
@@ -807,7 +807,7 @@ def exec_classif(arguments):
 
             directory = execution.init_log_file(dataset_name, args["Base"]["views"], args["Classification"]["type"],
                                               args["Base"]["log"], args["Base"]["debug"], args["Base"]["label"],
-                                              args["Base"]["res_dir"], args["Base"]["add_noise"], noise_std)
+                                              args["Base"]["res_dir"], args["Base"]["add_noise"], noise_std, args)
             random_state = execution.init_random_state(args["Base"]["random_state"], directory)
             stats_iter_random_states = execution.init_stats_iter_random_states(stats_iter,
                                                                         random_state)
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py
index e8437154..484b6657 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py
@@ -31,7 +31,6 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
     def __init__(self, random_state=None, view_weights=None,
                  monoview_classifier_name="decision_tree",
                  monoview_classifier_config={}):
-        print(type(view_weights), view_weights)
         super(WeightedLinearEarlyFusion, self).__init__(random_state=random_state)
         self.view_weights = view_weights
         self.monoview_classifier_name = monoview_classifier_name
@@ -84,12 +83,10 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
         example_indices, self.view_indices = get_examples_views_indices(dataset,
                                                                         example_indices,
                                                                         view_indices)
-        print(type(self.view_weights))
         if self.view_weights is None:
             self.view_weights = np.ones(len(self.view_indices), dtype=float)
         else:
             self.view_weights = np.array(self.view_weights)
-        print(self.view_weights)
         self.view_weights /= float(np.sum(self.view_weights))
 
         X = self.hdf5_to_monoview(dataset, example_indices)
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/configuration.py b/multiview_platform/mono_multi_view_classifiers/utils/configuration.py
index f297dcf0..0f044c61 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/configuration.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/configuration.py
@@ -1,4 +1,5 @@
 import yaml
+import os
 
 
 def get_the_args(path_to_config_file="../config_files/config.yml"):
@@ -18,3 +19,11 @@ def get_the_args(path_to_config_file="../config_files/config.yml"):
     with open(path_to_config_file, 'r') as stream:
         yaml_config = yaml.safe_load(stream)
     return yaml_config
+
+
+def save_config(directory, arguments):
+    """
+    Saves the config file in the result directory.
+    """
+    with open(os.path.join(directory, "config_file.yml"), "w") as stream:
+        yaml.dump(arguments, stream)
\ No newline at end of file
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/execution.py b/multiview_platform/mono_multi_view_classifiers/utils/execution.py
index 7bc2d155..a9168479 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/execution.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/execution.py
@@ -8,6 +8,7 @@ import numpy as np
 import sklearn
 
 from . import get_multiview_db as DB
+from ..utils.configuration import save_config
 
 
 def parse_the_args(arguments):
@@ -116,7 +117,7 @@ def get_database_function(name, type_var):
 
 
 def init_log_file(name, views, cl_type, log, debug, label,
-                  result_directory, add_noise, noise_std):
+                  result_directory, add_noise, noise_std, args):
     r"""Used to init the directory where the preds will be stored and the log file.
 
     First this function will check if the result directory already exists (only one per minute is allowed).
@@ -172,7 +173,7 @@ def init_log_file(name, views, cl_type, log, debug, label,
                         filemode='w')
     if log:
         logging.getLogger().addHandler(logging.StreamHandler())
-
+    save_config(result_directory, args)
     return result_directory
 
 
-- 
GitLab