diff --git a/code/bolsonaro/models/omp_forest.py b/code/bolsonaro/models/omp_forest.py index 2263e04943918d871a59127184f1f6f2da5bcaa2..5b211b0d2900a42ac5fedf95d1c98bee4f1fa5e9 100644 --- a/code/bolsonaro/models/omp_forest.py +++ b/code/bolsonaro/models/omp_forest.py @@ -33,8 +33,8 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta): # sklearn baseestimator api methods def fit(self, X_forest, y_forest, X_omp, y_omp): - print(y_forest.shape) - print(set([type(y) for y in y_forest])) + # print(y_forest.shape) + # print(set([type(y) for y in y_forest])) self._base_forest_estimator.fit(X_forest, y_forest) self._extract_subforest(X_omp, y_omp) # type: OrthogonalMatchingPursuit return self diff --git a/code/bolsonaro/models/omp_forest_classifier.py b/code/bolsonaro/models/omp_forest_classifier.py index 36d12be6727c25fcc029c13b1a13490f24be1295..ccaf3ebc2b630798f62ab17a1285ab28b366ed95 100644 --- a/code/bolsonaro/models/omp_forest_classifier.py +++ b/code/bolsonaro/models/omp_forest_classifier.py @@ -129,7 +129,7 @@ class OmpForestMulticlassClassifier(OmpForest): omp_trees_indices = np.nonzero(weights) label_names.append(class_label) atoms_binary = (forest_predictions[num_class].T - 0.5) * 2 # centré réduit de 0/1 à -1/1 - preds.append(np.sum(atoms_binary[omp_trees_indices], axis=0)) + preds.append(np.sum(atoms_binary[omp_trees_indices], axis=0)/len(omp_trees_indices)) num_class += 1 preds = np.array(preds).T