From 6483c0dcf19a50cae810974ec537f88ac2f6fda6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu>
Date: Fri, 6 Mar 2020 18:27:32 +0100
Subject: [PATCH] Correction on random extraction

---
 code/bolsonaro/trainer.py | 6 ++++--
 code/train.py             | 3 ++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py
index ebcfe80..6fcf0af 100644
--- a/code/bolsonaro/trainer.py
+++ b/code/bolsonaro/trainer.py
@@ -80,12 +80,14 @@ class Trainer(object):
             OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
         :return:
         """
-
         self._logger.debug('Training model using train set...')
         self._begin_time = time.time()
         if type(model) in [RandomForestRegressor, RandomForestClassifier]:
             if extracted_forest_size is not None:
-                model.estimators_ = np.random.choice(model.estimators_, extracted_forest_size)
+                estimators_index = np.arange(1000)
+                np.random.shuffle(estimators_index)
+                choosen_estimators = estimators_index[:extracted_forest_size]
+                model.estimators_ = np.array(model.estimators_)[choosen_estimators]
             else:
                 model.fit(
                     X=self._X_forest,
diff --git a/code/train.py b/code/train.py
index e70902b..8e48e14 100644
--- a/code/train.py
+++ b/code/train.py
@@ -10,6 +10,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory
 
 from dotenv import find_dotenv, load_dotenv
 import argparse
+import copy
 import json
 import pathlib
 import random
@@ -163,7 +164,7 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
             model_parameters.save(sub_models_dir, experiment_id)
             model = ModelFactory.build(dataset.task, model_parameters, library=library)
         else:
-            model = pretrained_estimator
+            model = copy.deepcopy(pretrained_estimator)
             pretrained_model_parameters.save(sub_models_dir, experiment_id)
 
         trainer.init(model, subsets_used=parameters['subsets_used'])
-- 
GitLab