Skip to content
Snippets Groups Projects
Commit d0e487aa authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Feat importance

parent 2c98dbd6
No related branches found
No related tags found
No related merge requests found
...@@ -31,8 +31,13 @@ def get_feature_importances(result, feature_ids=None, view_names=None,): ...@@ -31,8 +31,13 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
index=feature_ids[classifier_result.view_index]) index=feature_ids[classifier_result.view_index])
if hasattr(classifier_result.clf, 'feature_importances_'): if hasattr(classifier_result.clf, 'feature_importances_'):
print(classifier_result.classifier_name, classifier_result.view_name) print(classifier_result.classifier_name, classifier_result.view_name)
feature_importances[classifier_result.view_name][ feature_importances[classifier_result.view_name][
classifier_result.classifier_name] = classifier_result.clf.feature_importances_ classifier_result.classifier_name] = classifier_result.clf.feature_importances_
print(classifier_result.clf.feature_importances_.shape,
feature_importances[classifier_result.view_name][
classifier_result.classifier_name].shape)
else: else:
feature_importances[classifier_result.view_name][ feature_importances[classifier_result.view_name][
classifier_result.classifier_name] = np.zeros( classifier_result.classifier_name] = np.zeros(
...@@ -149,7 +154,6 @@ def plot_feature_relevance(file_name, feature_importance, ...@@ -149,7 +154,6 @@ def plot_feature_relevance(file_name, feature_importance,
if isinstance(score_df, dict): if isinstance(score_df, dict):
score_df = score_df["mean"] score_df = score_df["mean"]
for score in score_df.columns: for score in score_df.columns:
print(score)
if len(score.split("-"))>1: if len(score.split("-"))>1:
algo, view = score.split("-") algo, view = score.split("-")
list_ind = [ind for ind in feature_importance.index if ind.startswith(view)] list_ind = [ind for ind in feature_importance.index if ind.startswith(view)]
......
...@@ -43,9 +43,11 @@ def remove_compressed(exp_path): ...@@ -43,9 +43,11 @@ def remove_compressed(exp_path):
if __name__=="__main__": if __name__=="__main__":
for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"): # for dir in os.listdir("/home/baptiste/Documents/Gitwork/summit/results/"):
print(dir) # print(dir)
for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))): # for exp in os.listdir((os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir))):
print("\t", exp) # print("\t", exp)
explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp)) # explore_files(os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", dir, exp))
explore_files(
os.path.join("/home/baptiste/Documents/Gitwork/summit/results/", "mage_dset", "debug_started_2022_12_13-10_15_20_th"))
# simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html") # simplify_plotly("/home//baptiste/Documents/Gitwork/summit/results/hepatitis/debug_started_2022_03_16-15_06_55__/hepatitis-mean_on_10_iter-balanced_accuracy_p.html")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment