Skip to content
Snippets Groups Projects
Commit 347b7f5d authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix omp predictions in coherence

parent 3f85841e
Branches
No related tags found
1 merge request!22WIP: Resolve "coherence des arbres de predictions"
......@@ -145,16 +145,14 @@ class Trainer(object):
from sklearn.preprocessing import normalize
import itertools
if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]:
estimators = model.forest
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
estimators = model.forest
estimators = np.asarray(estimators)[model._omp.coef_ != 0]
elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
estimators = model.estimators_
predictions = list()
for ti in estimators:
predictions.append(ti.predict(X))
predictions = np.array([tree.predict(X) for tree in estimators])
predictions = normalize(predictions)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment