Skip to content
Snippets Groups Projects
Commit 61d553a5 authored by Baptiste Bauvin
Browse files

Feature importance correction

parent a8ffedcb
No related branches found
No related tags found
No related merge requests found
Pipeline #11639 failed
...@@ -25,13 +25,11 @@ def get_feature_importances(result, feature_ids=None, view_names=None,): ...@@ -25,13 +25,11 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
""" """
feature_importances = {} feature_importances = {}
for classifier_result in result: for classifier_result in result:
print(classifier_result.classifier_name)
if isinstance(classifier_result, MonoviewResult): if isinstance(classifier_result, MonoviewResult):
if classifier_result.view_name not in feature_importances: if classifier_result.view_name not in feature_importances:
feature_importances[classifier_result.view_name] = pd.DataFrame( feature_importances[classifier_result.view_name] = pd.DataFrame(
index=feature_ids[classifier_result.view_index]) index=feature_ids[classifier_result.view_index])
if hasattr(classifier_result.clf, 'feature_importances_'): if hasattr(classifier_result.clf, 'feature_importances_'):
print(classifier_result.classifier_name)
feature_importances[classifier_result.view_name][ feature_importances[classifier_result.view_name][
classifier_result.classifier_name] = classifier_result.clf.feature_importances_ classifier_result.classifier_name] = classifier_result.clf.feature_importances_
else: else:
...@@ -64,7 +62,6 @@ def publish_feature_importances(feature_importances, directory, database_name, ...@@ -64,7 +62,6 @@ def publish_feature_importances(feature_importances, directory, database_name,
os.mkdir(os.path.join(directory, "feature_importances")) os.mkdir(os.path.join(directory, "feature_importances"))
for view_name, feature_importance in feature_importances.items(): for view_name, feature_importance in feature_importances.items():
if view_name!="mv": if view_name!="mv":
if feature_stds is not None: if feature_stds is not None:
feature_std = feature_stds[view_name] feature_std = feature_stds[view_name]
else: else:
...@@ -75,37 +72,35 @@ def publish_feature_importances(feature_importances, directory, database_name, ...@@ -75,37 +72,35 @@ def publish_feature_importances(feature_importances, directory, database_name,
importance_dfs.append(feature_importance.set_index(pd.Index([view_name+"-"+ind for ind in list(feature_importance.index)])).fillna(0)) importance_dfs.append(feature_importance.set_index(pd.Index([view_name+"-"+ind for ind in list(feature_importance.index)])).fillna(0))
# importance_dfs.append(pd.DataFrame(index=[view_name+"-br"],
# columns=feature_importance.columns,
# data=np.zeros((1, len(
# feature_importance.columns)))))
std_dfs.append(feature_std.set_index(pd.Index([view_name+"-"+ind std_dfs.append(feature_std.set_index(pd.Index([view_name+"-"+ind
for ind for ind
in list(feature_std.index)])).fillna(0)) in list(feature_std.index)])).fillna(0))
# std_dfs.append(pd.DataFrame(index=[view_name + "-br"],
# columns=feature_std.columns,
# data=np.zeros((1, len(
# feature_std.columns)))))
if "mv" in feature_importances: if "mv" in feature_importances:
importance_dfs.append(feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :].fillna(0)) importance_dfs.append(feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :].fillna(0))
if len(importance_dfs)>0: if len(importance_dfs)>0:
print(importance_dfs)
indices=None indices=None
columns = None
for df in importance_dfs: for df in importance_dfs:
if indices is None: if indices is None:
indices = list(df.index) indices = list(df.index)
else: else:
indices += [ind for ind in df.index if ind not in indices] indices += [ind for ind in df.index if ind not in indices]
feat_imp_df = pd.DataFrame(index=indices) if columns is None:
feature_importances_df = pd.concat([feat_imp_df]+importance_dfs, axis=1) columns = list(df.columns)
print(feature_importances_df) else:
columns += [col for col in df.columns if col not in columns]
feature_importances_df = pd.DataFrame(index=indices, columns=columns)
for df in importance_dfs:
feature_importances_df = feature_importances_df.combine_first(df)
feature_importances_df = feature_importances_df/feature_importances_df.sum(axis=0) feature_importances_df = feature_importances_df/feature_importances_df.sum(axis=0)
feature_std_df = pd.DataFrame(index=indices, columns=columns)
if len(std_dfs)>0: if len(std_dfs)>0:
feature_std_df = pd.concat(std_dfs) for df in std_dfs:
feature_std_df = feature_std_df.combine_first(df)
else: else:
feature_std_df = pd.DataFrame() feature_std_df = pd.DataFrame()
if "mv" in feature_importances: if "mv" in feature_importances:
# feature_importances_df = pd.concat([feature_importances_df,feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :]], axis=1).fillna(0)
if feature_stds is not None: if feature_stds is not None:
feature_std_df = pd.concat([feature_std_df, feature_stds["mv"]], axis=1,).fillna(0) feature_std_df = pd.concat([feature_std_df, feature_stds["mv"]], axis=1,).fillna(0)
else: else:
......
...@@ -43,10 +43,14 @@ def remove_compressed(exp_path): ...@@ -43,10 +43,14 @@ def remove_compressed(exp_path):
if __name__=="__main__": if __name__=="__main__":
# for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"): for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"):
# print(dir) print(dir)
# for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))): for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
# print("\t", exp) print("\t", exp)
# explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp)) if os.path.isdir(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp)):
explore_files("/home/baptiste/Documents/Gitwork/biobanq_covid_expes/results/") explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
# simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html") # # explore_files("/home/baptiste/Documents/Gitwork/biobanq_covid_expes/results/")
# explore_files("/home/baptiste/Documents/Gitwork/summit/results/tnbc_mazid/debug_started_2023_03_24-11_27_46_thesis")
# explore_files(
# "/home/baptiste/Documents/Gitwork/summit/results/lives_thesis_EMF/debug_started_2023_03_24-10_02_21_thesis")
# # simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment