diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py index f09cc060cc8882878cb862d50d7faf9bd810f3f8..d95fb6e4faa4900139bb3ff4f00df694d9673fbc 100644 --- a/summit/multiview_platform/result_analysis/feature_importances.py +++ b/summit/multiview_platform/result_analysis/feature_importances.py @@ -25,13 +25,11 @@ def get_feature_importances(result, feature_ids=None, view_names=None,): """ feature_importances = {} for classifier_result in result: - print(classifier_result.classifier_name) if isinstance(classifier_result, MonoviewResult): if classifier_result.view_name not in feature_importances: feature_importances[classifier_result.view_name] = pd.DataFrame( index=feature_ids[classifier_result.view_index]) if hasattr(classifier_result.clf, 'feature_importances_'): - print(classifier_result.classifier_name) feature_importances[classifier_result.view_name][ classifier_result.classifier_name] = classifier_result.clf.feature_importances_ else: @@ -64,7 +62,6 @@ def publish_feature_importances(feature_importances, directory, database_name, os.mkdir(os.path.join(directory, "feature_importances")) for view_name, feature_importance in feature_importances.items(): if view_name!="mv": - if feature_stds is not None: feature_std = feature_stds[view_name] else: @@ -75,37 +72,35 @@ def publish_feature_importances(feature_importances, directory, database_name, importance_dfs.append(feature_importance.set_index(pd.Index([view_name+"-"+ind for ind in list(feature_importance.index)])).fillna(0)) - # importance_dfs.append(pd.DataFrame(index=[view_name+"-br"], - # columns=feature_importance.columns, - # data=np.zeros((1, len( - # feature_importance.columns))))) std_dfs.append(feature_std.set_index(pd.Index([view_name+"-"+ind for ind in list(feature_std.index)])).fillna(0)) - # std_dfs.append(pd.DataFrame(index=[view_name + "-br"], - # columns=feature_std.columns, - # data=np.zeros((1, len( - # feature_std.columns))))) if "mv" in feature_importances: importance_dfs.append(feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :].fillna(0)) if len(importance_dfs)>0: - print(importance_dfs) indices=None + columns = None for df in importance_dfs: if indices is None: indices = list(df.index) else: indices += [ind for ind in df.index if ind not in indices] - feat_imp_df = pd.DataFrame(index=indices) - feature_importances_df = pd.concat([feat_imp_df]+importance_dfs, axis=1) - print(feature_importances_df) + if columns is None: + columns = list(df.columns) + else: + columns += [col for col in df.columns if col not in columns] + feature_importances_df = pd.DataFrame(index=indices, columns=columns) + for df in importance_dfs: + feature_importances_df = feature_importances_df.combine_first(df) + feature_importances_df = feature_importances_df/feature_importances_df.sum(axis=0) + feature_std_df = pd.DataFrame(index=indices, columns=columns) if len(std_dfs)>0: - feature_std_df = pd.concat(std_dfs) + for df in std_dfs: + feature_std_df = feature_std_df.combine_first(df) else: feature_std_df = pd.DataFrame() if "mv" in feature_importances: - # feature_importances_df = pd.concat([feature_importances_df,feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :]], axis=1).fillna(0) if feature_stds is not None: feature_std_df = pd.concat([feature_std_df, feature_stds["mv"]], axis=1,).fillna(0) else: diff --git a/summit/multiview_platform/utils/compression.py b/summit/multiview_platform/utils/compression.py index 0a11b8950eab352514d261b93b6cc41642d7ee08..7ea525f81862a9561fb343cb045fd0ce13c467ae 100644 --- a/summit/multiview_platform/utils/compression.py +++ b/summit/multiview_platform/utils/compression.py @@ -43,10 +43,14 @@ def remove_compressed(exp_path): if __name__=="__main__": - # for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"): - # print(dir) - # for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))): - # print("\t", exp) - # explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp)) - explore_files("/home/baptiste/Documents/Gitwork/biobanq_covid_expes/results/") - # simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html") + for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"): + print(dir) + for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))): + print("\t", exp) + if os.path.isdir(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp)): + explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp)) + # # explore_files("/home/baptiste/Documents/Gitwork/biobanq_covid_expes/results/") + # explore_files("/home/baptiste/Documents/Gitwork/summit/results/tnbc_mazid/debug_started_2023_03_24-11_27_46_thesis") + # explore_files( + # "/home/baptiste/Documents/Gitwork/summit/results/lives_thesis_EMF/debug_started_2023_03_24-10_02_21_thesis") + # # simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")