Commit c04845a9 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Record selected trees for the 3 SOTA methods and random

parent 9d7ef0e7
......@@ -16,6 +16,7 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
self._library = library
self._extracted_forest_size = self._models_parameters.extracted_forest_size
self._score_metric = score_metric
self._selected_trees = list()
@property
def models_parameters(self):
......@@ -25,6 +26,10 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
def library(self):
return self._library
@property
def selected_trees(self):
return self._selected_trees
def fit(self, X_train, y_train, X_val, y_val):
scores_list = list()
for estimator in self._library:
......@@ -33,7 +38,7 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
class_list = list(self._library)
m = np.argmax(np.asarray(scores_list))
self._ensemble_selected = [class_list[m]]
self._selected_trees = [class_list[m]]
temp_pred = class_list[m].predict(X_val)
del class_list[m]
for k in range(self._extracted_forest_size - 1):
......@@ -47,7 +52,7 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
candidate_index = j
best_score = temp_score
temp_pred = np.delete(temp_pred, -1, 0)
self._ensemble_selected.append(class_list[candidate_index])
self._selected_trees.append(class_list[candidate_index])
temp_pred = np.vstack((temp_pred, class_list[candidate_index].predict(X_val)))
del class_list[candidate_index]
......@@ -57,7 +62,7 @@ class EnsembleSelectionForestRegressor(BaseEstimator, metaclass=ABCMeta):
def predict_base_estimator(self, X):
predictions = list()
for tree in self._ensemble_selected:
for tree in self._selected_trees:
predictions.append(tree.predict(X))
mean_predictions = np.mean(np.array(predictions), axis=0)
return mean_predictions
......
......@@ -22,11 +22,16 @@ class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
random_state=self._models_parameters.seed, n_jobs=-1)
self._extracted_forest_size = self._models_parameters.extracted_forest_size
self._score_metric = score_metric
self._selected_trees = list()
@property
def models_parameters(self):
return self._models_parameters
@property
def selected_trees(self):
return self._selected_trees
def fit(self, X_train, y_train, X_val, y_val):
self._estimator.fit(X_train, y_train)
......@@ -45,6 +50,7 @@ class KMeansForestRegressor(BaseEstimator, metaclass=ABCMeta):
extracted_forest_sizes[i], labels, X_val, y_val, self._score_metric)
for i in range(self._extracted_forest_size))
self._selected_trees = pruned_forest
self._estimator.estimators_ = pruned_forest
def _prune_forest_job(self, prune_forest_job_pb, c, labels, X_val, y_val, score_metric):
......
from bolsonaro import LOG_PATH
from bolsonaro.error_handling.logger_factory import LoggerFactory
from bolsonaro.utils import omp_premature_warning
from abc import abstractmethod, ABCMeta
import numpy as np
......@@ -14,6 +15,7 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
self._base_forest_estimator = base_forest_estimator
self._models_parameters = models_parameters
self._logger = LoggerFactory.create(LOG_PATH, __name__)
self._selected_trees = list()
@property
def models_parameters(self):
......@@ -145,7 +147,7 @@ class SingleOmpForest(OmpForest):
Make all the base tree predictions
:param X: a Forest
:return: a np.array of the predictions of the trees selected by OMP without applyong the weight
:return: a np.array of the predictions of the trees selected by OMP without applying the weight
"""
forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
......
......@@ -17,18 +17,22 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
random_state=self._models_parameters.seed, n_jobs=-1)
self._extracted_forest_size = self._models_parameters.extracted_forest_size
self._score_metric = score_metric
self._selected_trees = list()
@property
def models_parameters(self):
return self._models_parameters
@property
def selected_trees(self):
return self._selected_trees
def fit(self, X_train, y_train, X_val, y_val):
self._estimator.fit(X_train, y_train)
y_val_pred = self._estimator.predict(X_val)
forest_pred = self._score_metric(y_val, y_val_pred)
forest = self._estimator.estimators_
selected_trees = list()
tree_list = list(self._estimator.estimators_)
val_scores = list()
......@@ -57,12 +61,13 @@ class SimilarityForestRegressor(BaseEstimator, metaclass=ABCMeta):
tree_list.insert(j, lonely_tree)
val_scores.insert(j, lonely_tree.predict(X_val))
tree_list_bar.update(1)
selected_trees.append(tree_list[found_index])
self._selected_trees.append(tree_list[found_index])
del tree_list[found_index]
del val_scores[found_index]
pruning_forest_bar.update(1)
pruned_forest = list(set(forest) - set(selected_trees))
self._selected_trees = set(self._selected_trees)
pruned_forest = list(set(forest) - self._selected_trees)
self._estimator.estimators_ = pruned_forest
def score(self, X, y):
......
......@@ -13,6 +13,8 @@ from sklearn.metrics import mean_squared_error, accuracy_score
import time
import datetime
import numpy as np
import os
import pickle
class Trainer(object):
......@@ -36,6 +38,7 @@ class Trainer(object):
else classification_score_metric.__name__
self._base_score_metric_name = base_regression_score_metric.__name__ if dataset.task == Task.REGRESSION \
else base_classification_score_metric.__name__
self._selected_trees = ''
@property
def score_metric_name(self):
......@@ -93,6 +96,7 @@ class Trainer(object):
X=self._X_forest,
y=self._y_forest
)
self._selected_trees = model.estimators_
else:
model.fit(
self._X_forest,
......@@ -151,6 +155,13 @@ class Trainer(object):
elif type(model) == OmpForestBinaryClassifier:
model_weights = model._omp
if type(model) in [SimilarityForestRegressor, EnsembleSelectionForestRegressor, KMeansForestRegressor]:
self._selected_trees = model.selected_trees
if len(self._selected_trees) > 0:
with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file:
pickle.dump(self._selected_trees, output_file)
results = ModelRawResults(
model_weights=model_weights,
training_time=self._end_time - self._begin_time,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment