diff --git a/multiview_platform/mono_multi_view_classifiers/exec_classif.py b/multiview_platform/mono_multi_view_classifiers/exec_classif.py index 88ae2c8717d873d46a2ceb85f1144abcc0adb5c5..5a91fb0dc9d47607a837f7ed3184f9a3a41126e6 100644 --- a/multiview_platform/mono_multi_view_classifiers/exec_classif.py +++ b/multiview_platform/mono_multi_view_classifiers/exec_classif.py @@ -807,7 +807,7 @@ def exec_classif(arguments): directory = execution.init_log_file(dataset_name, args["Base"]["views"], args["Classification"]["type"], args["Base"]["log"], args["Base"]["debug"], args["Base"]["label"], - args["Base"]["res_dir"], args["Base"]["add_noise"], noise_std) + args["Base"]["res_dir"], args["Base"]["add_noise"], noise_std, args) random_state = execution.init_random_state(args["Base"]["random_state"], directory) stats_iter_random_states = execution.init_stats_iter_random_states(stats_iter, random_state) diff --git a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py index e8437154145c38b3b6e0a8d82224e31c7d569eb7..484b6657dac9858231677e1900ca950f927128f6 100644 --- a/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py +++ b/multiview_platform/mono_multi_view_classifiers/multiview_classifiers/weighted_linear_early_fusion.py @@ -31,7 +31,6 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier): def __init__(self, random_state=None, view_weights=None, monoview_classifier_name="decision_tree", monoview_classifier_config={}): - print(type(view_weights), view_weights) super(WeightedLinearEarlyFusion, self).__init__(random_state=random_state) self.view_weights = view_weights self.monoview_classifier_name = monoview_classifier_name @@ -84,12 +83,10 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier): example_indices, self.view_indices = get_examples_views_indices(dataset, example_indices, view_indices) - print(type(self.view_weights)) if self.view_weights is None: self.view_weights = np.ones(len(self.view_indices), dtype=float) else: self.view_weights = np.array(self.view_weights) - print(self.view_weights) self.view_weights /= float(np.sum(self.view_weights)) X = self.hdf5_to_monoview(dataset, example_indices) diff --git a/multiview_platform/mono_multi_view_classifiers/utils/configuration.py b/multiview_platform/mono_multi_view_classifiers/utils/configuration.py index f297dcf09deebab08b29573a45344fbd7e40a822..0f044c61ad0a0ef83083abb1d9a603ff57d4eb08 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/configuration.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/configuration.py @@ -1,4 +1,5 @@ import yaml +import os def get_the_args(path_to_config_file="../config_files/config.yml"): @@ -18,3 +19,11 @@ def get_the_args(path_to_config_file="../config_files/config.yml"): with open(path_to_config_file, 'r') as stream: yaml_config = yaml.safe_load(stream) return yaml_config + + +def save_config(directory, arguments): + """ + Saves the config file in the result directory. + """ + with open(os.path.join(directory, "config_file.yml"), "w") as stream: + yaml.dump(arguments, stream) \ No newline at end of file diff --git a/multiview_platform/mono_multi_view_classifiers/utils/execution.py b/multiview_platform/mono_multi_view_classifiers/utils/execution.py index 7bc2d155607928dd4df9eba428ab4cdf23c60cdc..a91684796716772ec25bd11d74505cca8b514bf7 100644 --- a/multiview_platform/mono_multi_view_classifiers/utils/execution.py +++ b/multiview_platform/mono_multi_view_classifiers/utils/execution.py @@ -8,6 +8,7 @@ import numpy as np import sklearn from . import get_multiview_db as DB +from ..utils.configuration import save_config def parse_the_args(arguments): @@ -116,7 +117,7 @@ def get_database_function(name, type_var): def init_log_file(name, views, cl_type, log, debug, label, - result_directory, add_noise, noise_std): + result_directory, add_noise, noise_std, args): r"""Used to init the directory where the preds will be stored and the log file. First this function will check if the result directory already exists (only one per minute is allowed). @@ -172,7 +173,7 @@ def init_log_file(name, views, cl_type, log, debug, label, filemode='w') if log: logging.getLogger().addHandler(logging.StreamHandler()) - + save_config(result_directory, args) return result_directory