From be5bc24a2e46a5f5da7fbbf10ae13fa9f080f0cf Mon Sep 17 00:00:00 2001
From: Charly Lamothe <charly.lamothe@univ-amu.fr>
Date: Fri, 6 Mar 2020 02:07:41 +0100
Subject: [PATCH] Add a resume mode for the experiment training (and make
 overwriting of the experiment's resulting model optional)

---
 code/train.py | 67 ++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 48 insertions(+), 19 deletions(-)
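
Note: the non-empty .pickle check added below is duplicated between the
extracted-forest branch and the base-forest branch. A possible follow-up would
factor it into a small helper; the sketch below is only illustrative, and the
helper name has_non_empty_pickle is an assumption rather than something this
patch introduces:

    import os

    def has_non_empty_pickle(directory):
        """Return True if `directory` already holds a non-empty .pickle result file."""
        if not os.path.isdir(directory):
            return False
        for file_name in os.listdir(directory):
            if os.path.splitext(file_name)[1] != '.pickle':
                continue
            return os.path.getsize(os.path.join(directory, file_name)) > 0
        return False

Both call sites would then reduce to already_exists = has_non_empty_pickle(sub_models_dir).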

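Note: the last hunk only comments out the unconditional shutil.rmtree, so re-running
an existing experiment id never removes previous results. If overwriting should stay
available on demand, one option is to guard the removal behind a command-line switch.
The sketch below is only a suggestion; the --overwrite flag, the args.overwrite
attribute and the models_dir stand-in are assumptions, not part of this patch:

    import argparse
    import os
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_id', type=int, default=None)
    parser.add_argument('--overwrite', action='store_true',
                        help='Remove any previous results of this experiment before training.')
    args = parser.parse_args()

    models_dir = 'models'  # stand-in for parameters['models_dir'] used in train.py

    if args.experiment_id and args.overwrite:
        # Wipe the previous run only when explicitly requested
        shutil.rmtree(os.path.join(models_dir, str(args.experiment_id)), ignore_errors=True)
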
diff --git a/code/train.py b/code/train.py
index 0ca2b47..7df811d 100644
--- a/code/train.py
+++ b/code/train.py
@@ -58,6 +58,21 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
         for extracted_forest_size in parameters['extracted_forest_size']:
             logger.info('extracted_forest_size={}'.format(extracted_forest_size))
             sub_models_dir = models_dir + os.sep + 'extracted_forest_sizes' + os.sep + str(extracted_forest_size)
+
+            # Check if the result file already exists
+            already_exists = False
+            if os.path.isdir(sub_models_dir):
+                sub_models_dir_files = os.listdir(sub_models_dir)
+                for file_name in sub_models_dir_files:
+                    if os.path.splitext(file_name)[1] != '.pickle':
+                        continue
+                    already_exists = os.path.getsize(
+                        os.path.join(sub_models_dir, file_name)) > 0
+                    break
+            if already_exists:
+                logger.info(f'Extracted forest {extracted_forest_size} result already exists. Skipping...')
+                continue
+
             pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
 
             model_parameters = ModelParameters(
@@ -80,24 +95,38 @@ def process_job(seed, parameters, experiment_id, hyperparameters):
         forest_size = hyperparameters['n_estimators']
         logger.info('Base forest training with fixed forest size of {}'.format(forest_size))
         sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size)
-        pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
-
-        model_parameters = ModelParameters(
-            extracted_forest_size=forest_size,
-            normalize_D=parameters['normalize_D'],
-            subsets_used=parameters['subsets_used'],
-            normalize_weights=parameters['normalize_weights'],
-            seed=seed,
-            hyperparameters=hyperparameters,
-            extraction_strategy=parameters['extraction_strategy']
-        )
-        model_parameters.save(sub_models_dir, experiment_id)
-
-        model = ModelFactory.build(dataset.task, model_parameters)
-
-        trainer.init(model, subsets_used=parameters['subsets_used'])
-        trainer.train(model)
-        trainer.compute_results(model, sub_models_dir)
+
+        # Check if the result file already exists
+        already_exists = False
+        if os.path.isdir(sub_models_dir):
+            sub_models_dir_files = os.listdir(sub_models_dir)
+            for file_name in sub_models_dir_files:
+                if os.path.splitext(file_name)[1] != '.pickle':
+                    continue
+                already_exists = os.path.getsize(
+                    os.path.join(sub_models_dir, file_name)) > 0
+                break
+        if already_exists:
+            logger.info('Base forest result already exists. Skipping...')
+        else:
+            # Result file missing or empty: train the base forest and save its results
+            pathlib.Path(sub_models_dir).mkdir(parents=True, exist_ok=True)
+            model_parameters = ModelParameters(
+                extracted_forest_size=forest_size,
+                normalize_D=parameters['normalize_D'],
+                subsets_used=parameters['subsets_used'],
+                normalize_weights=parameters['normalize_weights'],
+                seed=seed,
+                hyperparameters=hyperparameters,
+                extraction_strategy=parameters['extraction_strategy']
+            )
+            model_parameters.save(sub_models_dir, experiment_id)
+
+            model = ModelFactory.build(dataset.task, model_parameters)
+
+            trainer.init(model, subsets_used=parameters['subsets_used'])
+            trainer.train(model)
+            trainer.compute_results(model, sub_models_dir)
     logger.info('Training done')
 
 """
@@ -220,7 +249,7 @@ if __name__ == "__main__":
 
     if args.experiment_id:
         experiment_id = args.experiment_id
-        shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
+        #shutil.rmtree(os.path.join(parameters['models_dir'], str(experiment_id)), ignore_errors=True)
     else:
         # Resolve the next experiment id number (last id + 1)
         experiment_id = resolve_experiment_id(parameters['models_dir'])
-- 
GitLab