Skip to content
Snippets Groups Projects
Commit 6483c0dc authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Correction on random extraction

parent 138660cb
No related branches found
No related tags found
2 merge requests!20Resolve "integration-sota",!19WIP: Resolve "Adding new datasets"
...@@ -80,12 +80,14 @@ class Trainer(object): ...@@ -80,12 +80,14 @@ class Trainer(object):
OmpForestBinaryClassifier, OmpForestMulticlassClassifier. OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
:return: :return:
""" """
self._logger.debug('Training model using train set...') self._logger.debug('Training model using train set...')
self._begin_time = time.time() self._begin_time = time.time()
if type(model) in [RandomForestRegressor, RandomForestClassifier]: if type(model) in [RandomForestRegressor, RandomForestClassifier]:
if extracted_forest_size is not None: if extracted_forest_size is not None:
model.estimators_ = np.random.choice(model.estimators_, extracted_forest_size) estimators_index = np.arange(1000)
np.random.shuffle(estimators_index)
choosen_estimators = estimators_index[:extracted_forest_size]
model.estimators_ = np.array(model.estimators_)[choosen_estimators]
else: else:
model.fit( model.fit(
X=self._X_forest, X=self._X_forest,
......
...@@ -10,6 +10,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory ...@@ -10,6 +10,7 @@ from bolsonaro.error_handling.logger_factory import LoggerFactory
from dotenv import find_dotenv, load_dotenv from dotenv import find_dotenv, load_dotenv
import argparse import argparse
import copy
import json import json
import pathlib import pathlib
import random import random
...@@ -163,7 +164,7 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz ...@@ -163,7 +164,7 @@ def extracted_forest_size_job(extracted_forest_size_job_pb, extracted_forest_siz
model_parameters.save(sub_models_dir, experiment_id) model_parameters.save(sub_models_dir, experiment_id)
model = ModelFactory.build(dataset.task, model_parameters, library=library) model = ModelFactory.build(dataset.task, model_parameters, library=library)
else: else:
model = pretrained_estimator model = copy.deepcopy(pretrained_estimator)
pretrained_model_parameters.save(sub_models_dir, experiment_id) pretrained_model_parameters.save(sub_models_dir, experiment_id)
trainer.init(model, subsets_used=parameters['subsets_used']) trainer.init(model, subsets_used=parameters['subsets_used'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment