diff --git a/summit/multiview_platform/result_analysis/execution.py b/summit/multiview_platform/result_analysis/execution.py
index 279891e7d30ea3fa2262e25eea013a883d61db52..7e046df70434a1e1a99213120b84ff89e23a6e24 100644
--- a/summit/multiview_platform/result_analysis/execution.py
+++ b/summit/multiview_platform/result_analysis/execution.py
@@ -84,7 +84,7 @@ def analyze_iterations(results, benchmark_argument_dictionaries, stats_iter,
         sample_errors = get_sample_errors(labels, result)
         feature_importances = get_feature_importances(result,
                                                       feature_ids=feature_ids,
-                                                      view_names=view_names)
+                                                      view_names=view_names,)
         durations = get_duration(result)
         directory = arguments["directory"]
 
@@ -98,7 +98,7 @@ def analyze_iterations(results, benchmark_argument_dictionaries, stats_iter,
         publish_sample_errors(sample_errors, directory, database_name,
                               labels_names, sample_ids, labels)
         publish_feature_importances(feature_importances, directory,
-                                    database_name)
+                                    database_name, metric_scores=metrics_scores)
         plot_durations(durations, directory, database_name)
 
         iter_results["metrics_scores"][iter_index] = metrics_scores
diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py
index c3f234f6750db4d96e53a747a9c3fdc70373e634..0735c6eaf12ef953957f60261e12c9767e2a357b 100644
--- a/summit/multiview_platform/result_analysis/feature_importances.py
+++ b/summit/multiview_platform/result_analysis/feature_importances.py
@@ -7,7 +7,7 @@ import plotly
 from ..monoview.monoview_utils import MonoviewResult
 
 
-def get_feature_importances(result, feature_ids=None, view_names=None):
+def get_feature_importances(result, feature_ids=None, view_names=None,):
     r"""Extracts the feature importance from the monoview results and stores
     them in a dictionnary :
     feature_importance[view_name] is a pandas.DataFrame of size n_feature*n_clf
@@ -49,7 +49,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None):
 
 
 def publish_feature_importances(feature_importances, directory, database_name,
-                                feature_stds=None):  # pragma: no cover
+                                feature_stds=None, metric_scores=None):  # pragma: no cover
     importance_dfs = []
     std_dfs = []
     if not os.path.exists(os.path.join(directory, "feature_importances")):
@@ -94,6 +94,9 @@ def publish_feature_importances(feature_importances, directory, database_name,
             feature_std_df = pd.concat([feature_std_df, fake], axis=1,).fillna(0)
     plot_feature_importances(os.path.join(directory, "feature_importances",
                                      database_name), feature_importances_df, feature_std_df)
+    if metric_scores is not None:
+        plot_feature_relevance(os.path.join(directory, "feature_importances",
+                                     database_name), feature_importances_df, feature_std_df, metric_scores)
 
 
 def plot_feature_importances(file_name, feature_importance,
@@ -125,3 +128,39 @@ def plot_feature_importances(file_name, feature_importance,
     plotly.offline.plot(fig, filename=file_name + ".html",
                         auto_open=False)
     del fig
+
+
+def plot_feature_relevance(file_name, feature_importance,
+                           feature_std, metric_scores):  # pragma: no cover
+    """Plot the feature importances weighted by each classifier's test score.
+
+    For every metric whose name ends with ``"*"``, each column of
+    ``feature_importance`` is multiplied by the ``'test'`` entry of the
+    matching column of that metric's score DataFrame. Score columns named
+    ``"<algo>-<view>"`` only rescale the rows of column ``<algo>`` whose
+    index label starts with ``<view>``. The result is passed to
+    ``plot_feature_importances`` with ``"_relevance"`` appended to
+    ``file_name``.
+
+    Works on a copy, so the caller's ``feature_importance`` is left untouched.
+    """
+    # Copy first: the scaling below is in place and must not leak to the caller.
+    relevance = feature_importance.copy()
+    for metric, score_df in metric_scores.items():
+        if not metric.endswith("*"):
+            continue
+        for score in score_df.columns:
+            test_score = score_df[score]['test']
+            # NOTE(review): splits on the first "-" only, assuming any extra
+            # dash belongs to the view name -- confirm against the naming scheme.
+            algo, sep, view = score.partition("-")
+            if sep:
+                rows = [ind for ind in relevance.index
+                        if ind.startswith(view)]
+                # Single .loc call: the chained ``df[algo].loc[rows] *= x``
+                # form is not guaranteed to write back to the DataFrame.
+                relevance.loc[rows, algo] *= test_score
+            else:
+                relevance[score] *= test_score
+    plot_feature_importances(file_name + "_relevance", relevance,
+                             feature_std)