Skip to content
Snippets Groups Projects
Commit be5bc24a authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Add on resume mode for the experiment training (and set the overwrite of the...

Add on resume mode for the experiment training (and set the overwrite of the resulting model of the experiment optional)
parent 31724b30
No related branches found
No related tags found
1 merge request!12Resolve "integration-sota"
......@@ -58,6 +58,21 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
for extracted_forest_size in parameters['extracted_forest_size']:
logger.info('extracted_forest_size={}'.format(extracted_forest_size))
sub_models_dir = models_dir + os.sep + 'extracted_forest_sizes' + os.sep + str(extracted_forest_size)
# Check if the result file already exists
already_exists = False
if os.path.isdir(sub_models_dir):
sub_models_dir_files = os.listdir(sub_models_dir)
for file_name in sub_models_dir_files:
if '.pickle' != os.path.splitext(file_name)[1]:
continue
else:
already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
break
if already_exists:
logger.info(f'Extracted forest {extracted_forest_size} result already exists. Skipping...')
continue
pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
model_parameters = ModelParameters(
......@@ -80,24 +95,38 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
forest_size = hyperparameters['n_estimators']
logger.info('Base forest training with fixed forest size of {}'.format(forest_size))
sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size)
pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
model_parameters = ModelParameters(
extracted_forest_size=forest_size,
normalize_D=parameters['normalize_D'],
subsets_used=parameters['subsets_used'],
normalize_weights=parameters['normalize_weights'],
seed=seed,
hyperparameters=hyperparameters,
extraction_strategy=parameters['extraction_strategy']
)
model_parameters.save(sub_models_dir, experiment_id)
model = ModelFactory.build(dataset.task, model_parameters)
trainer.init(model, subsets_used=parameters['subsets_used'])
trainer.train(model)
trainer.compute_results(model, sub_models_dir)
# Check if the result file already exists
already_exists = False
if os.path.isdir(sub_models_dir):
sub_models_dir_files = os.listdir(sub_models_dir)
for file_name in sub_models_dir_files:
if '.pickle' != os.path.splitext(file_name)[1]:
continue
else:
already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
break
if already_exists:
logger.info('Base forest result already exists. Skipping...')
else:
pass
pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
model_parameters = ModelParameters(
extracted_forest_size=forest_size,
normalize_D=parameters['normalize_D'],
subsets_used=parameters['subsets_used'],
normalize_weights=parameters['normalize_weights'],
seed=seed,
hyperparameters=hyperparameters,
extraction_strategy=parameters['extraction_strategy']
)
model_parameters.save(sub_models_dir, experiment_id)
model = ModelFactory.build(dataset.task, model_parameters)
trainer.init(model, subsets_used=parameters['subsets_used'])
trainer.train(model)
trainer.compute_results(model, sub_models_dir)
logger.info('Training done')
"""
......@@ -220,7 +249,7 @@ if __name__ == "__main__":
if args.experiment_id:
experiment_id = args.experiment_id
shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
#shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
else:
# Resolve the next experiment id number (last id + 1)
experiment_id = resolve_experiment_id(parameters['models_dir'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment