Commit 6483c0dc authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Correction on random extraction

parent 138660cb
......@@ -80,12 +80,14 @@ class Trainer(object):
OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
:return:
"""
self._logger.debug('Training model using train set...')
self._begin_time = time.time()
if type(model) in [RandomForestRegressor, RandomForestClassifier]:
if extracted_forest_size is not None:
model.estimators_ = np.random.choice(model.estimators_, extracted_forest_size)
estimators_index = np.arange(1000)
np.random.shuffle(estimators_index)
choosen_estimators = estimators_index[:extracted_forest_size]
model.estimators_ = np.array(model.estimators_)[choosen_estimators]
else:
model.fit(
X=self._X_forest,
......
......@@ -10,6 +10,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory
from dotenv import find_dotenv, load_dotenv
import argparse
import copy
import json
import pathlib
import random
......@@ -163,7 +164,7 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
model_parameters.save(sub_models_dir, experiment_id)
model = ModelFactory.build(dataset.task, model_parameters, library=library)
else:
model = pretrained_estimator
model = copy.deepcopy(pretrained_estimator)
pretrained_model_parameters.save(sub_models_dir, experiment_id)
trainer.init(model, subsets_used=parameters['subsets_used'])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment