diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index 10eea289a2989765fd66bfea5ae7fa940d69de3e..218ba76ddbe7b9d8b68eb39b6e202c8d061107d0 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -96,6 +96,8 @@ class Trainer(object): result = self._regression_score_metric(y_true, y_pred) elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, RandomForestClassifier]: y_pred = model.predict(X) + if type(model) is OmpForestBinaryClassifier: + y_pred = y_pred.round() result = self._classification_score_metric(y_true, y_pred) return result diff --git a/results/breast_cancer/stage1/losses.png b/results/breast_cancer/stage1/losses.png new file mode 100644 index 0000000000000000000000000000000000000000..64e0e14cdab6f518eecbdbf900a89cdb2647397b Binary files /dev/null and b/results/breast_cancer/stage1/losses.png differ