Commit 6483c0dc authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Correction on random extraction

parent 138660cb
......@@ -80,12 +80,14 @@ class Trainer(object):
OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
self._logger.debug('Training model using train set...')
self._begin_time = time.time()
if type(model) in [RandomForestRegressor, RandomForestClassifier]:
if extracted_forest_size is not None:
model.estimators_ = np.random.choice(model.estimators_, extracted_forest_size)
estimators_index = np.arange(1000)
choosen_estimators = estimators_index[:extracted_forest_size]
model.estimators_ = np.array(model.estimators_)[choosen_estimators]
......@@ -10,6 +10,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory
from dotenv import find_dotenv, load_dotenv
import argparse
import copy
import json
import pathlib
import random
......@@ -163,7 +164,7 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz, experiment_id)
model =, model_parameters, library=library)
model = pretrained_estimator
model = copy.deepcopy(pretrained_estimator), experiment_id)
trainer.init(model, subsets_used=parameters['subsets_used'])
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment