diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index cb1702d0b7e49381fc4c881476a3d5797e1d31c4..fc289afd1f8301197f5b1dd8be3bb134deca4a91 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -109,6 +109,7 @@ class Trainer(object): y_pred = model.predict_no_weights(X) if type(model) is OmpForestBinaryClassifier: y_pred = np.sign(y_pred) + y_pred = np.where(y_pred==0, 1, y_pred) result = self._classification_score_metric(y_true, y_pred) return result