From 72465f530ad6c06274a6557e1c7505a803b6320c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu> Date: Thu, 5 Mar 2020 11:52:59 +0100 Subject: [PATCH] Add divider --- code/bolsonaro/models/omp_forest.py | 4 ++-- code/bolsonaro/models/omp_forest_classifier.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py index 2263e04..5b211b0 100644 --- a/code/bolsonaro/models/omp_forest.py +++ b/code/bolsonaro/models/omp_forest.py @@ -33,8 +33,8 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta): # sklearn baseestimator api methods def fit(self, X_forest, y_forest, X_omp, y_omp): - print(y_forest.shape) - print(set([type(y) for y in y_forest])) + # print(y_forest.shape) + # print(set([type(y) for y in y_forest])) self._base_forest_estimator.fit(X_forest, y_forest) self._extract_subforest(X_omp, y_omp) # type: OrthogonalMatchingPursuit return self diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py index 36d12be..ccaf3eb 100644 --- a/code/bolsonaro/models/omp_forest_classifier.py +++ b/code/bolsonaro/models/omp_forest_classifier.py @@ -129,7 +129,7 @@ class OmpForestMulticlassClassifier(OmpForest): omp_trees_indices = np.nonzero(weights) label_names.append(class_label) atoms_binary = (forest_predictions[num_class].T - 0.5) * 2 # centré réduit de 0/1 à -1/1 - preds.append(np.sum(atoms_binary[omp_trees_indices], axis=0)) + preds.append(np.sum(atoms_binary[omp_trees_indices], axis=0)/len(omp_trees_indices)) num_class += 1 preds = np.array(preds).T -- GitLab