From d5f2b26b158b1b8f4efb46dd42d44fa4357018e7 Mon Sep 17 00:00:00 2001 From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr> Date: Tue, 22 Mar 2022 11:15:53 -0400 Subject: [PATCH] Added feature relevance --- .../result_analysis/execution.py | 4 ++-- .../result_analysis/feature_importances.py | 22 +++++++++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/summit/multiview_platform/result_analysis/execution.py b/summit/multiview_platform/result_analysis/execution.py index 279891e7..7e046df7 100644 --- a/summit/multiview_platform/result_analysis/execution.py +++ b/summit/multiview_platform/result_analysis/execution.py @@ -84,7 +84,7 @@ def analyze_iterations(results, benchmark_argument_dictionaries, stats_iter, sample_errors = get_sample_errors(labels, result) feature_importances = get_feature_importances(result, feature_ids=feature_ids, - view_names=view_names) + view_names=view_names,) durations = get_duration(result) directory = arguments["directory"] @@ -98,7 +98,7 @@ def analyze_iterations(results, benchmark_argument_dictionaries, stats_iter, publish_sample_errors(sample_errors, directory, database_name, labels_names, sample_ids, labels) publish_feature_importances(feature_importances, directory, - database_name) + database_name, metric_scores=metrics_scores) plot_durations(durations, directory, database_name) iter_results["metrics_scores"][iter_index] = metrics_scores diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py index c3f234f6..0735c6ea 100644 --- a/summit/multiview_platform/result_analysis/feature_importances.py +++ b/summit/multiview_platform/result_analysis/feature_importances.py @@ -7,7 +7,7 @@ import plotly from ..monoview.monoview_utils import MonoviewResult -def get_feature_importances(result, feature_ids=None, view_names=None): +def get_feature_importances(result, feature_ids=None, view_names=None,): r"""Extracts the feature importance from the monoview results and stores them in a dictionnary : feature_importance[view_name] is a pandas.DataFrame of size n_feature*n_clf @@ -49,7 +49,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None): def publish_feature_importances(feature_importances, directory, database_name, - feature_stds=None): # pragma: no cover + feature_stds=None, metric_scores=None): # pragma: no cover importance_dfs = [] std_dfs = [] if not os.path.exists(os.path.join(directory, "feature_importances")): @@ -94,6 +94,9 @@ def publish_feature_importances(feature_importances, directory, database_name, feature_std_df = pd.concat([feature_std_df, fake], axis=1,).fillna(0) plot_feature_importances(os.path.join(directory, "feature_importances", database_name), feature_importances_df, feature_std_df) + if metric_scores is not None: + plot_feature_relevance(os.path.join(directory, "feature_importances", + database_name), feature_importances_df, feature_std_df, metric_scores) def plot_feature_importances(file_name, feature_importance, @@ -125,3 +128,18 @@ def plot_feature_importances(file_name, feature_importance, plotly.offline.plot(fig, filename=file_name + ".html", auto_open=False) del fig + + +def plot_feature_relevance(file_name, feature_importance, + feature_std, metric_scores): # pragma: no cover + for metric, score_df in metric_scores.items(): + if metric.endswith("*"): + for score in score_df.columns: + if len(score.split("-"))>1: + algo, view = score.split("-") + feature_importance[algo].loc[[ind for ind in feature_importance.index if ind.startswith(view)]]*=score_df[score]['test'] + else: + feature_importance[score] *= score_df[score]['test'] + file_name+="_relevance" + plot_feature_importances(file_name, feature_importance, + feature_std) -- GitLab