Skip to content
Snippets Groups Projects
Commit 95f543a1 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix omp_wo_weights when coupled with normalize_D. Fix missing selected trees...

Fix omp_wo_weights when coupled with normalize_D. Fix missing selected trees saving for omp and random. Update compute results (not done yet). Fix omp_distillation.
parent a98fd932
No related branches found
No related tags found
1 merge request!23Resolve "integration-sota"
......@@ -41,7 +41,7 @@ class OmpForest(BaseEstimator, metaclass=ABCMeta):
# print(set([type(y) for y in y_forest]))
self._base_forest_estimator.fit(X_forest, y_forest)
self._extract_subforest(X_omp,
self.predict_base_estimator(X_forest) if use_distillation else y_omp) # type: OrthogonalMatchingPursuit
self.predict_base_estimator(X_omp) if use_distillation else y_omp) # type: OrthogonalMatchingPursuit
return self
def _extract_subforest(self, X, y):
......@@ -153,11 +153,6 @@ class SingleOmpForest(OmpForest):
"""
forest_predictions = np.array([tree.predict(X) for tree in self._base_forest_estimator.estimators_])
if self._models_parameters.normalize_D:
forest_predictions = forest_predictions.T
forest_predictions /= self._forest_norms
forest_predictions = forest_predictions.T
weights = self._omp.coef_
select_trees = np.mean(forest_predictions[weights != 0], axis=0)
return select_trees
......@@ -19,11 +19,11 @@ class OmpForestBinaryClassifier(SingleOmpForest):
def _check_classes(self, y):
assert len(set(y).difference({-1, 1})) == 0, "Classes for binary classifier must be {-1, +1}"
def fit(self, X_forest, y_forest, X_omp, y_omp):
def fit(self, X_forest, y_forest, X_omp, y_omp, use_distillation=False):
self._check_classes(y_forest)
self._check_classes(y_omp)
return super().fit(X_forest, y_forest, X_omp, y_omp)
return super().fit(X_forest, y_forest, X_omp, y_omp, use_distillation=use_distillation)
def _base_estimator_predictions(self, X):
predictions_0_1 = super()._base_estimator_predictions(X)
......@@ -42,9 +42,6 @@ class OmpForestBinaryClassifier(SingleOmpForest):
forest_predictions = self._base_estimator_predictions(X)
if self._models_parameters.normalize_D:
forest_predictions /= self._forest_norms
weights = self._omp.coef_
omp_trees_predictions = forest_predictions[:, weights != 0]
......
......@@ -155,16 +155,7 @@ class Trainer(object):
return result
def _evaluate_predictions(self, model, X, aggregation_function):
if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
estimators = model.forest
estimators = np.asarray(estimators)[model._omp.coef_ != 0]
elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
estimators = model.selected_trees
elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
estimators = model.estimators_
predictions = np.array([tree.predict(X) for tree in estimators])
predictions = np.array([tree.predict(X) for tree in self._selected_trees])
predictions = normalize(predictions)
......@@ -187,6 +178,10 @@ class Trainer(object):
if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
self._selected_trees = model.selected_trees
elif type(model) in [OmpForestRegressor, OmpForestMulticlassClassifier, OmpForestBinaryClassifier]:
self._selected_trees = np.asarray(model.forest)[model._omp.coef_ != 0]
elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
self._selected_trees = model.estimators_
if len(self._selected_trees) > 0:
with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file:
......
......@@ -473,9 +473,10 @@ if __name__ == "__main__":
30 + 1,
endpoint=True)[1:]).astype(np.int)).tolist()"""
extracted_forest_sizes = [4, 7, 11, 14, 18, 22, 25, 29, 32, 36, 40, 43, 47, 50, 54, 58, 61, 65, 68, 72, 76, 79, 83, 86, 90, 94, 97, 101, 104, 108]
#extracted_forest_sizes = [4, 7, 11, 14, 18, 22, 25, 29, 32, 36, 40, 43, 47, 50, 54, 58, 61, 65, 68, 72, 76, 79, 83, 86, 90, 94, 97, 101, 104, 108]
extracted_forest_sizes = [str(forest_size) for forest_size in extracted_forest_sizes]
#extracted_forest_sizes = [str(forest_size) for forest_size in extracted_forest_sizes]
extracted_forest_sizes= list()
# base_with_params
logger.info('Loading base_with_params experiment scores...')
......@@ -508,8 +509,10 @@ if __name__ == "__main__":
for i in range(3, len(args.experiment_ids)):
if 'kmeans' in args.experiment_ids[i]:
label = 'kmeans'
elif 'similarity' in args.experiment_ids[i]:
label = 'similarity'
elif 'similarity_similarities' in args.experiment_ids[i]:
label = 'similarity_similarities'
elif 'similarity_predictions' in args.experiment_ids[i]:
label = 'similarity_predictions'
elif 'ensemble' in args.experiment_ids[i]:
label = 'ensemble'
else:
......@@ -528,7 +531,7 @@ if __name__ == "__main__":
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)
Plotter.plot_stage2_losses(
file_path=output_path + os.sep + f"losses_{'-'.join(all_labels)}_test.png",
file_path=output_path + os.sep + f"losses_{'-'.join(all_labels)}_test_train,dev.png",
all_experiment_scores=all_scores,
all_labels=all_labels,
x_value=with_params_extracted_forest_sizes,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment