Skip to content
Snippets Groups Projects
Commit 347b7f5d authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix omp predictions in coherence

parent 3f85841e
No related branches found
No related tags found
1 merge request!22WIP: Resolve "coherence des arbres de predictions"
This commit is part of merge request !22. Comments created here will be created in the context of that merge request.
...@@ -145,16 +145,14 @@ class Trainer(object): ...@@ -145,16 +145,14 @@ class Trainer(object):
from sklearn.preprocessing import normalize from sklearn.preprocessing import normalize
import itertools import itertools
if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor]: if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
estimators = model.forest OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
estimators = model.forest estimators = model.forest
estimators = np.asarray(estimators)[model._omp.coef_ != 0]
elif type(model) in [RandomForestRegressor, RandomForestClassifier]: elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
estimators = model.estimators_ estimators = model.estimators_
predictions = list() predictions = np.array([tree.predict(X) for tree in estimators])
for ti in estimators:
predictions.append(ti.predict(X))
predictions = normalize(predictions) predictions = normalize(predictions)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment