Skip to content
Snippets Groups Projects
Commit 96ff6093 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix normalize_D missing transpose and fix stage4 result path

parent 34070d2c
No related branches found
No related tags found
1 merge request!20Resolve "integration-sota"
...@@ -123,7 +123,9 @@ class SingleOmpForest(OmpForest): ...@@ -123,7 +123,9 @@ class SingleOmpForest(OmpForest):
forest_predictions = self._base_estimator_predictions(X) forest_predictions = self._base_estimator_predictions(X)
if self._models_parameters.normalize_D: if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
return self._make_omp_weighted_prediction(forest_predictions, self._omp, self._models_parameters.normalize_weights) return self._make_omp_weighted_prediction(forest_predictions, self._omp, self._models_parameters.normalize_weights)
...@@ -139,7 +141,9 @@ class SingleOmpForest(OmpForest): ...@@ -139,7 +141,9 @@ class SingleOmpForest(OmpForest):
forest_predictions = self._base_estimator_predictions(X).T forest_predictions = self._base_estimator_predictions(X).T
if self._models_parameters.normalize_D: if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
weights = self._omp.coef_ weights = self._omp.coef_
omp_trees_indices = np.nonzero(weights)[0] omp_trees_indices = np.nonzero(weights)[0]
......
...@@ -37,7 +37,9 @@ class OmpForestBinaryClassifier(SingleOmpForest): ...@@ -37,7 +37,9 @@ class OmpForestBinaryClassifier(SingleOmpForest):
forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_]) forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_])
if self._models_parameters.normalize_D: if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
weights = self._omp.coef_ weights = self._omp.coef_
omp_trees_indices = np.nonzero(weights) omp_trees_indices = np.nonzero(weights)
...@@ -119,7 +121,9 @@ class OmpForestMulticlassClassifier(OmpForest): ...@@ -119,7 +121,9 @@ class OmpForestMulticlassClassifier(OmpForest):
forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_]).T forest_predictions = np.array([tree.predict_proba(X) for tree in self._base_forest_estimator.estimators_]).T
if self._models_parameters.normalize_D: if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
label_names = [] label_names = []
preds = [] preds = []
......
...@@ -389,7 +389,7 @@ if __name__ == "__main__": ...@@ -389,7 +389,7 @@ if __name__ == "__main__":
raise ValueError('Score metrics of all experiments must be the same.') raise ValueError('Score metrics of all experiments must be the same.')
experiments_score_metric = base_with_params_experiment_score_metric experiments_score_metric = base_with_params_experiment_score_metric
output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4_fix') output_path = os.path.join(args.results_dir, args.dataset_name, 'stage4')
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True) pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
Plotter.plot_stage2_losses( Plotter.plot_stage2_losses(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment