diff --git a/code/train.py b/code/train.py
index 0ca2b47ae7c49726d2de75091f4bc80f883ffb81..7df811d20c1413004413a1a2d9c151fea8a1ad8d 100644
--- a/code/train.py
+++ b/code/train.py
@@ -58,6 +58,21 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
         for extracted_forest_size in parameters['extracted_forest_size']:
             logger.info('extracted_forest_size={}'.format(extracted_forest_size))
             sub_models_dir = models_dir + os.sep + 'extracted_forest_sizes' + os.sep + str(extracted_forest_size)
+
+            # Check if the result file already exists
+            already_exists = False
+            if os.path.isdir(sub_models_dir):
+                sub_models_dir_files = os.listdir(sub_models_dir)
+                for file_name in sub_models_dir_files:
+                    if '.pickle' != os.path.splitext(file_name)[1]:
+                        continue
+                    else:
+                        already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
+                        break
+            if already_exists:
+                logger.info(f'Extracted forest {extracted_forest_size} result already exists. Skipping...')
+                continue
+
             pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
 
             model_parameters = ModelParameters(
@@ -80,24 +95,37 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
         forest_size = hyperparameters['n_estimators']
         logger.info('Base forest training with fixed forest size of {}'.format(forest_size))
         sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size)
-        pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
-
-        model_parameters = ModelParameters(
-            extracted_forest_size=forest_size,
-            normalize_D=parameters['normalize_D'],
-            subsets_used=parameters['subsets_used'],
-            normalize_weights=parameters['normalize_weights'],
-            seed=seed,
-            hyperparameters=hyperparameters,
-            extraction_strategy=parameters['extraction_strategy']
-        )
-        model_parameters.save(sub_models_dir, experiment_id)
-
-        model = ModelFactory.build(dataset.task, model_parameters)
-
-        trainer.init(model, subsets_used=parameters['subsets_used'])
-        trainer.train(model)
-        trainer.compute_results(model, sub_models_dir)
+
+        # Check if the result file already exists
+        already_exists = False
+        if os.path.isdir(sub_models_dir):
+            sub_models_dir_files = os.listdir(sub_models_dir)
+            for file_name in sub_models_dir_files:
+                if '.pickle' != os.path.splitext(file_name)[1]:
+                    continue
+                else:
+                    already_exists = os.path.getsize(os.path.join(sub_models_dir, file_name)) > 0
+                    break
+        if already_exists:
+            logger.info('Base forest result already exists. Skipping...')
+        else:
+            pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
+            model_parameters = ModelParameters(
+                extracted_forest_size=forest_size,
+                normalize_D=parameters['normalize_D'],
+                subsets_used=parameters['subsets_used'],
+                normalize_weights=parameters['normalize_weights'],
+                seed=seed,
+                hyperparameters=hyperparameters,
+                extraction_strategy=parameters['extraction_strategy']
+            )
+            model_parameters.save(sub_models_dir, experiment_id)
+
+            model = ModelFactory.build(dataset.task, model_parameters)
+
+            trainer.init(model, subsets_used=parameters['subsets_used'])
+            trainer.train(model)
+            trainer.compute_results(model, sub_models_dir)
     logger.info('Training done')
 
     """
@@ -220,7 +248,7 @@ if __name__ == "__main__":
 
     if args.experiment_id:
         experiment_id = args.experiment_id
-        shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
+        #shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
     else:
         # Resolve the next experiment id number (last id + 1)
         experiment_id = resolve_experiment_id(parameters['models_dir'])