Skip to content
Snippets Groups Projects
Commit 96ff6093 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix normalize_D missing transpose and fix stage4 result path

parent 34070d2c
Branches
No related tags found
1 merge request!20Resolve "integration-sota"
......@@ -123,7 +123,9 @@ class SingleOmpForest(OmpForest):
forest_predictions = self._base_estimator_predictions(X)
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
return self._make_omp_weighted_prediction(forest_predictions, self._omp, self._models_parameters.normalize_weights)
......@@ -139,7 +141,9 @@ class SingleOmpForest(OmpForest):
forest_predictions = self._base_estimator_predictions(X).T
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
weights = self._omp.coef_
omp_trees_indices = np.nonzero(weights)[0]
......
......@@ -37,7 +37,9 @@ class OmpForestBinaryClassifier(SingleOmpForest):
forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_])
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
weights = self._omp.coef_
omp_trees_indices = np.nonzero(weights)
......@@ -119,7 +121,9 @@ class OmpForestMulticlassClassifier(OmpForest):
forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_]).T
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
label_names = []
preds = []
......
......@@ -389,7 +389,7 @@ if __name__ == "__main__":
raise ValueError('Score metrics of all experiments must be the same.')
experiments_score_metric = base_with_params_experiment_score_metric
output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4_fix')
output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4')
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
Plotter.plot_stage2_losses(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment