Commit 6483c0dc authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Correction on random extraction

parent 138660cb
...@@ -80,12 +80,14 @@ class Trainer(object): ...@@ -80,12 +80,14 @@ class Trainer(object):
OmpForestBinaryClassifier, OmpForestMulticlassClassifier. OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
:return: :return:
""" """
self._logger.debug('Training model using train set...') self._logger.debug('Training model using train set...')
self._begin_time = time.time() self._begin_time = time.time()
if type(model) in [RandomForestRegressor, RandomForestClassifier]: if type(model) in [RandomForestRegressor, RandomForestClassifier]:
if extracted_forest_size is not None: if extracted_forest_size is not None:
model.estimators_ = np.random.choice(model.estimators_, extracted_forest_size) estimators_index = np.arange(1000)
np.random.shuffle(estimators_index)
choosen_estimators = estimators_index[:extracted_forest_size]
model.estimators_ = np.array(model.estimators_)[choosen_estimators]
else: else:
model.fit( model.fit(
X=self._X_forest, X=self._X_forest,
......
...@@ -10,6 +10,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory ...@@ -10,6 +10,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory
from dotenv import find_dotenv, load_dotenv from dotenv import find_dotenv, load_dotenv
import argparse import argparse
import copy
import json import json
import pathlib import pathlib
import random import random
...@@ -163,7 +164,7 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz ...@@ -163,7 +164,7 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
model_parameters.save(sub_models_dir, experiment_id) model_parameters.save(sub_models_dir, experiment_id)
model = ModelFactory.build(dataset.task, model_parameters, library=library) model = ModelFactory.build(dataset.task, model_parameters, library=library)
else: else:
model = pretrained_estimator model = copy.deepcopy(pretrained_estimator)
pretrained_model_parameters.save(sub_models_dir, experiment_id) pretrained_model_parameters.save(sub_models_dir, experiment_id)
trainer.init(model, subsets_used=parameters['subsets_used']) trainer.init(model, subsets_used=parameters['subsets_used'])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment