Commit 72465f53 authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Add divider

parent 1a22e391
......@@ -33,8 +33,8 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
# sklearn baseestimator api methods
def fit(self, X_forest, y_forest, X_omp, y_omp):
print(y_forest.shape)
print(set([type(y) for y in y_forest]))
# print(y_forest.shape)
# print(set([type(y) for y in y_forest]))
self._base_forest_estimator.fit(X_forest, y_forest)
self._extract_subforest(X_omp, y_omp) # type: OrthogonalMatchingPursuit
return self
......
......@@ -129,7 +129,7 @@ class OmpForestMulticlassClassifier(OmpForest):
omp_trees_indices = np.nonzero(weights)
label_names.append(class_label)
atoms_binary = (forest_predictions[num_class].T - 0.5) * 2 # centré réduit de 0/1 à -1/1
preds.append(np.sum(atoms_binary[omp_trees_indices], axis=0))
preds.append(np.sum(atoms_binary[omp_trees_indices], axis=0)/len(omp_trees_indices))
num_class += 1
preds = np.array(preds).T
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment