diff --git a/code/bolsonaro/models/model_factory.py b/code/bolsonaro/models/model_factory.py
index 6f1836278675662ac7ef57f8bf98cc4c8284dc26..56e267fb8208bd733a428d43f9b6bfe6d19b16da 100644
--- a/code/bolsonaro/models/model_factory.py
+++ b/code/bolsonaro/models/model_factory.py
@@ -19,7 +19,7 @@ class ModelFactory(object):
             raise ValueError("Unsupported task '{}'".format(task))
 
         if task == Task.BINARYCLASSIFICATION:
-            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
+            if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
                 return OmpForestBinaryClassifier(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
@@ -36,7 +36,7 @@ class ModelFactory(object):
             else:
                 raise ValueError('Invalid extraction strategy')
         elif task == Task.REGRESSION:
-            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
+            if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
                 return OmpForestRegressor(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestRegressor(**model_parameters.hyperparameters,
@@ -53,7 +53,7 @@ class ModelFactory(object):
             else:
                 raise ValueError('Invalid extraction strategy')
         elif task == Task.MULTICLASSIFICATION:
-            if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
+            if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
                 return OmpForestMulticlassClassifier(model_parameters)
             elif model_parameters.extraction_strategy == 'random':
                 return RandomForestClassifier(**model_parameters.hyperparameters,
diff --git a/code/train.py b/code/train.py
index 10dbf7354837cab803202a8307c671f0def0f274..457e1c405d203e79c4b58f366ef6adfd596d8948 100644
--- a/code/train.py
+++ b/code/train.py
@@ -76,14 +76,7 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
         pretrained_estimator = None
         pretrained_model_parameters = None
 
-    if parameters['extraction_strategy'] != 'none':
-        with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb:
-            Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i],
-                models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer,
-                pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters,
-                use_distillation=parameters['extraction_strategy'] == 'omp_distillation')
-                for i in range(len(parameters['extracted_forest_size'])))
-    else:
+    if parameters['extraction_strategy'] == 'none':
         forest_size = hyperparameters['n_estimators']
         logger.info('Base forest training with fixed forest size of {}'.format(forest_size))
         sub_models_dir = models_dir + os.sep + 'forest_size' + os.sep + str(forest_size)
@@ -118,6 +111,32 @@ def seed_job(seed_job_pb, seed, parameters, experiment_id, hyperparameters, verb
             trainer.init(model, subsets_used=parameters['subsets_used'])
             trainer.train(model)
             trainer.compute_results(model, sub_models_dir)
+    elif parameters['extraction_strategy'] == 'omp_nn':
+        forest_size = hyperparameters['n_estimators']
+        model_parameters = ModelParameters(
+            extracted_forest_size=forest_size,
+            normalize_D=parameters['normalize_D'],
+            subsets_used=parameters['subsets_used'],
+            normalize_weights=parameters['normalize_weights'],
+            seed=seed,
+            hyperparameters=hyperparameters,
+            extraction_strategy=parameters['extraction_strategy']
+        )
+        model_parameters.save(sub_models_dir, experiment_id)
+
+        model = ModelFactory.build(dataset.task, model_parameters)
+
+        trainer.init(model, subsets_used=parameters['subsets_used'])
+        trainer.train(model)
+        trainer.compute_results(model, sub_models_dir)
+    else:
+        with tqdm_joblib(tqdm(total=len(parameters['extracted_forest_size']), disable=not verbose)) as extracted_forest_size_job_pb:
+            Parallel(n_jobs=-1)(delayed(extracted_forest_size_job)(extracted_forest_size_job_pb, parameters['extracted_forest_size'][i],
+                models_dir, seed, parameters, dataset, hyperparameters, experiment_id, trainer,
+                pretrained_estimator=pretrained_estimator, pretrained_model_parameters=pretrained_model_parameters,
+                use_distillation=parameters['extraction_strategy'] == 'omp_distillation')
+                for i in range(len(parameters['extracted_forest_size'])))
+
     logger.info(f'Training done for seed {seed_str}')
     seed_job_pb.update(1)
 
@@ -232,7 +251,7 @@ if __name__ == "__main__":
     parser.add_argument('--skip_best_hyperparams', action='store_true', default=DEFAULT_SKIP_BEST_HYPERPARAMS, help='Do not use the best hyperparameters if there exist.')
     parser.add_argument('--save_experiment_configuration', nargs='+', default=None, help='Save the experiment parameters specified in the command line in a file. Args: {{stage_num}} {{name}}')
     parser.add_argument('--job_number', nargs='?', type=int, default=DEFAULT_JOB_NUMBER, help='Specify the number of job used during the parallelisation across seeds.')
-    parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, random, none, similarity_similarities, similarity_predictions, kmeans, ensemble.')
+    parser.add_argument('--extraction_strategy', nargs='?', type=str, default=DEFAULT_EXTRACTION_STRATEGY, help='Specify the strategy to apply to extract the trees from the forest. Either omp, omp_nn, random, none, similarity_similarities, similarity_predictions, kmeans, ensemble.')
     parser.add_argument('--overwrite', action='store_true', default=DEFAULT_OVERWRITE, help='Overwrite the experiment id')
     args = parser.parse_args()
 
@@ -243,7 +262,7 @@ if __name__ == "__main__":
     else:
         parameters = args.__dict__
 
-    if parameters['extraction_strategy'] not in ['omp', 'omp_distillation', 'random', 'none', 'similarity_similarities', 'similarity_predictions', 'kmeans', 'ensemble']:
+    if parameters['extraction_strategy'] not in ['omp', 'omp_nn', 'omp_distillation', 'random', 'none', 'similarity_similarities', 'similarity_predictions', 'kmeans', 'ensemble']:
         raise ValueError('Specified extraction strategy {} is not supported.'.format(parameters['extraction_strategy']))
 
     pathlib.Path(parameters['models_dir']).mkdir(parents=True, exist_ok=True)