From 6483c0dcf19a50cae810974ec537f88ac2f6fda6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu>
Date: Fri, 6 Mar 2020 18:27:32 +0100
Subject: [PATCH] Correction on random extraction

---
 code/bolsonaro/trainer.py | 6 ++++--
 code/train.py             | 3 ++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index ebcfe80..6fcf0af 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -80,12 +80,14 @@ class Trainer(object):
             OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
         :return:
         """
-
         self._logger.debug('Training model using train set...')
         self._begin_time = time.time()
         if type(model) in [RandomForestRegressor, RandomForestClassifier]:
             if extracted_forest_size is not None:
-                model.estimators_ = np.random.choice(model.estimators_, extracted_forest_size)
+                estimators_index = np.arange(1000)
+                np.random.shuffle(estimators_index)
+                choosen_estimators = estimators_index[:extracted_forest_size]
+                model.estimators_ = np.array(model.estimators_)[choosen_estimators]
             else:
                 model.fit(
                     X=self._X_forest,
diff --git a/code/train.py b/code/train.py
index e70902b..8e48e14 100644
--- a/code/train.py
+++ b/code/train.py
@@ -10,6 +10,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory
 
 from dotenv import find_dotenv, load_dotenv
 import argparse
+import copy
 import json
 import pathlib
 import random
@@ -163,7 +164,7 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
         model_parameters.save(sub_models_dir, experiment_id)
         model = ModelFactory.build(dataset.task, model_parameters, library=library)
     else:
-        model = pretrained_estimator
+        model = copy.deepcopy(pretrained_estimator)
         pretrained_model_parameters.save(sub_models_dir, experiment_id)
 
     trainer.init(model, subsets_used=parameters['subsets_used'])
-- 
GitLab