Skip to content
Snippets Groups Projects
Commit c6160646 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix _evaluate_predictions estimators setting for SOTA methods using new selected_trees attribute

parent 8b7fe366
No related branches found
1 merge request!23Resolve "integration-sota"
...@@ -154,10 +154,12 @@ class Trainer(object): ...@@ -154,10 +154,12 @@ class Trainer(object):
return result return result
def _evaluate_predictions(self, model, X, aggregation_function): def _evaluate_predictions(self, model, X, aggregation_function):
if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
OmpForestBinaryClassifier, OmpForestMulticlassClassifier, SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
estimators = model.forest estimators = model.forest
estimators = np.asarray(estimators)[model._omp.coef_ != 0] estimators = np.asarray(estimators)[model._omp.coef_ != 0]
elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
estimators = model.selected_trees
elif type(model) in [RandomForestRegressor, RandomForestClassifier]: elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
estimators = model.estimators_ estimators = model.estimators_
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment