Skip to content
Snippets Groups Projects
Commit 4d4c0848 authored by Charly Lamothe
Browse files

Merge from master

parent 1e4c3afe
No related branches found
No related tags found
No related merge requests found
@@ -19,7 +19,7 @@ class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
 def __init__(self, models_parameters, score_metric=mean_squared_error):
 self._models_parameters = models_parameters
 self._estimator = RandomForestRegressor(**self._models_parameters.hyperparameters,
-random_state=self._models_parameters.seed, n_jobs=-1)
+random_state=self._models_parameters.seed, n_jobs=2)
 self._extracted_forest_size = self._models_parameters.extracted_forest_size
 self._score_metric = score_metric
 self._selected_trees = list()
@@ -46,7 +46,7 @@ class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
 # For each cluster select the best tree on the validation set
 extracted_forest_sizes = list(range(self._extracted_forest_size))
 with tqdm_joblib(tqdm(total=self._extracted_forest_size, disable=True)) as prune_forest_job_pb:
-pruned_forest = Parallel(n_jobs=-1)(delayed(self._prune_forest_job)(prune_forest_job_pb,
+pruned_forest = Parallel(n_jobs=2)(delayed(self._prune_forest_job)(prune_forest_job_pb,
 extracted_forest_sizes[i], labels, X_val, y_val, self._score_metric)
 for i in range(self._extracted_forest_size))
@@ -56,7 +56,7 @@ class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
 def _prune_forest_job(self, prune_forest_job_pb, c, labels, X_val, y_val, score_metric):
 index = np.where(labels == c)[0]
 with tqdm_joblib(tqdm(total=len(index), disable=True)) as cluster_job_pb:
-cluster = Parallel(n_jobs=-1)(delayed(self._cluster_job)(cluster_job_pb, index[i], X_val,
+cluster = Parallel(n_jobs=2)(delayed(self._cluster_job)(cluster_job_pb, index[i], X_val,
 y_val, score_metric) for i in range(len(index)))
 best_tree_index = np.argmax(cluster)
 prune_forest_job_pb.update()
......
@@ -283,6 +283,8 @@ if __name__ == "__main__":
 parameters['extracted_forest_size_samples'] + 1,
 endpoint=True)[1:]).astype(np.int)).tolist()
+logger.info(f"extracted forest sizes: {parameters['extracted_forest_size']}")
 if parameters['seeds'] != None and parameters['random_seed_number'] > 1:
 logger.warning('seeds and random_seed_number parameters are both specified. Seeds will be used.')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment