Skip to content
Snippets Groups Projects
Commit 6483c0dc authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Correction on random extraction

parent 138660cb
No related branches found
No related tags found
2 merge requests!20Resolve "integration-sota",!19WIP: Resolve "Adding new datasets"
......@@ -80,12 +80,14 @@ class Trainer(object):
OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
:return:
"""
self._logger.debug('Training model using train set...')
self._begin_time = time.time()
if type(model) in [RandomForestRegressor, RandomForestClassifier]:
if extracted_forest_size is not None:
model.estimators_ = np.random.choice(model.estimators_, extracted_forest_size)
estimators_index = np.arange(1000)
np.random.shuffle(estimators_index)
choosen_estimators = estimators_index[:extracted_forest_size]
model.estimators_ = np.array(model.estimators_)[choosen_estimators]
else:
model.fit(
X=self._X_forest,
......
......@@ -10,6 +10,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory
from dotenv import find_dotenv, load_dotenv
import argparse
import copy
import json
import pathlib
import random
......@@ -163,7 +164,7 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
model_parameters.save(sub_models_dir, experiment_id)
model = ModelFactory.build(dataset.task, model_parameters, library=library)
else:
model = pretrained_estimator
model = copy.deepcopy(pretrained_estimator)
pretrained_model_parameters.save(sub_models_dir, experiment_id)
trainer.init(model, subsets_used=parameters['subsets_used'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment