Skip to content
Snippets Groups Projects
Commit 9dd0172b authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Removed prints and now saving config file in result directory

parent 112fff2e
Branches
Tags
No related merge requests found
......@@ -807,7 +807,7 @@ def exec_classif(arguments):
directory = execution.init_log_file(dataset_name, args["Base"]["views"], args["Classification"]["type"],
args["Base"]["log"], args["Base"]["debug"], args["Base"]["label"],
args["Base"]["res_dir"], args["Base"]["add_noise"], noise_std)
args["Base"]["res_dir"], args["Base"]["add_noise"], noise_std, args)
random_state = execution.init_random_state(args["Base"]["random_state"], directory)
stats_iter_random_states = execution.init_stats_iter_random_states(stats_iter,
random_state)
......
......@@ -31,7 +31,6 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
def __init__(self, random_state=None, view_weights=None,
monoview_classifier_name="decision_tree",
monoview_classifier_config={}):
print(type(view_weights), view_weights)
super(WeightedLinearEarlyFusion, self).__init__(random_state=random_state)
self.view_weights = view_weights
self.monoview_classifier_name = monoview_classifier_name
......@@ -84,12 +83,10 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
example_indices, self.view_indices = get_examples_views_indices(dataset,
example_indices,
view_indices)
print(type(self.view_weights))
if self.view_weights is None:
self.view_weights = np.ones(len(self.view_indices), dtype=float)
else:
self.view_weights = np.array(self.view_weights)
print(self.view_weights)
self.view_weights /= float(np.sum(self.view_weights))
X = self.hdf5_to_monoview(dataset, example_indices)
......
import yaml
import os
def get_the_args(path_to_config_file="../config_files/config.yml"):
......@@ -18,3 +19,11 @@ def get_the_args(path_to_config_file="../config_files/config.yml"):
with open(path_to_config_file, 'r') as stream:
yaml_config = yaml.safe_load(stream)
return yaml_config
def save_config(directory, arguments):
"""
Saves the config file in the result directory.
"""
with open(os.path.join(directory, "config_file.yml"), "w") as stream:
yaml.dump(arguments, stream)
\ No newline at end of file
......@@ -8,6 +8,7 @@ import numpy as np
import sklearn
from . import get_multiview_db as DB
from ..utils.configuration import save_config
def parse_the_args(arguments):
......@@ -116,7 +117,7 @@ def get_database_function(name, type_var):
def init_log_file(name, views, cl_type, log, debug, label,
result_directory, add_noise, noise_std):
result_directory, add_noise, noise_std, args):
r"""Used to init the directory where the preds will be stored and the log file.
First this function will check if the result directory already exists (only one per minute is allowed).
......@@ -172,7 +173,7 @@ def init_log_file(name, views, cl_type, log, debug, label,
filemode='w')
if log:
logging.getLogger().addHandler(logging.StreamHandler())
save_config(result_directory, args)
return result_directory
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment