From be5bc24a2e46a5f5da7fbbf10ae13fa9f080f0cf Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Fri, 6 Mar 2020 02:07:41 +0100
Subject: [PATCH] Add on resume mode for the experiment training (and set the
 overwrite of the resulting model of the experiment optional)

---
 code/train.py | 67 ++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 48 insertions(+), 19 deletions(-)

diff --git a/code/train.py b/code/train.py
index 0ca2b47..7df811d 100644
--- a/code/train.py
+++ b/code/train.py
@@ -58,6 +58,21 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
     for extracted_forest_size in parameters['extracted_forest_size']:
         logger.info('extracted_forest_size={}'.format(extracted_forest_size))
         sub_models_dir = models_dir + os.sep + 'extracted_forest_sizes' + os.sep + str(extracted_forest_size)
+
+        # Check if the result file already exists
+        already_exists = False
+        if os.path.isdir(sub_models_dir):
+            sub_models_dir_files = os.listdir(sub_models_dir)
+            for file_name in sub_models_dir_files:
+                if '.pickle' != os.path.splitext(file_name)[1]:
+                    continue
+                else:
+                    already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
+                    break
+        if already_exists:
+            logger.info(f'Extracted forest {extracted_forest_size} result already exists. Skipping...')
+            continue
+
         pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
 
         model_parameters = ModelParameters(
@@ -80,24 +95,38 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
     forest_size = hyperparameters['n_estimators']
     logger.info('Base forest training with fixed forest size of {}'.format(forest_size))
     sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size)
-    pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
-
-    model_parameters = ModelParameters(
-        extracted_forest_size=forest_size,
-        normalize_D=parameters['normalize_D'],
-        subsets_used=parameters['subsets_used'],
-        normalize_weights=parameters['normalize_weights'],
-        seed=seed,
-        hyperparameters=hyperparameters,
-        extraction_strategy=parameters['extraction_strategy']
-    )
-    model_parameters.save(sub_models_dir, experiment_id)
-
-    model = ModelFactory.build(dataset.task, model_parameters)
-
-    trainer.init(model, subsets_used=parameters['subsets_used'])
-    trainer.train(model)
-    trainer.compute_results(model, sub_models_dir)
+
+    # Check if the result file already exists
+    already_exists = False
+    if os.path.isdir(sub_models_dir):
+        sub_models_dir_files = os.listdir(sub_models_dir)
+        for file_name in sub_models_dir_files:
+            if '.pickle' != os.path.splitext(file_name)[1]:
+                continue
+            else:
+                already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
+                break
+    if already_exists:
+        logger.info('Base forest result already exists. Skipping...')
+    else:
+        pass
+        pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
+        model_parameters = ModelParameters(
+            extracted_forest_size=forest_size,
+            normalize_D=parameters['normalize_D'],
+            subsets_used=parameters['subsets_used'],
+            normalize_weights=parameters['normalize_weights'],
+            seed=seed,
+            hyperparameters=hyperparameters,
+            extraction_strategy=parameters['extraction_strategy']
+        )
+        model_parameters.save(sub_models_dir, experiment_id)
+
+        model = ModelFactory.build(dataset.task, model_parameters)
+
+        trainer.init(model, subsets_used=parameters['subsets_used'])
+        trainer.train(model)
+        trainer.compute_results(model, sub_models_dir)
 
     logger.info('Training done')
     """
@@ -220,7 +249,7 @@ if __name__ == "__main__":
 
     if args.experiment_id:
         experiment_id = args.experiment_id
-        shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
+        #shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
     else:
         # Resolve the next experiment id number (last id + 1)
         experiment_id = resolve_experiment_id(parameters['models_dir'])
-- 
GitLab
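
For reference, the resume check added in both hunks follows the same pattern: a training run is considered finished as soon as its output directory contains a non-empty .pickle file. Below is a minimal sketch of that check as a standalone helper, assuming only the standard library; the name results_already_exist is illustrative and not part of the patch.

import os


def results_already_exist(sub_models_dir):
    """Return True if sub_models_dir already holds a non-empty pickled result.

    Hypothetical helper mirroring the duplicated check in the patch: the first
    '.pickle' file found decides the outcome, exactly like the 'break' in the
    original loops.
    """
    if not os.path.isdir(sub_models_dir):
        return False
    for file_name in os.listdir(sub_models_dir):
        if os.path.splitext(file_name)[1] != '.pickle':
            continue
        # A non-empty pickle file is treated as a completed run.
        return os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
    return False

With such a helper, each branch of process_job could reduce to a single guard, e.g. skipping the loop iteration with "if results_already_exist(sub_models_dir): continue", which is one way to avoid repeating the check for the extracted forests and the base forest.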