Skip to content
Snippets Groups Projects
Commit 0468df2e authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Feature_importance for multiview only possible

"
parent 21969c99
No related branches found
No related tags found
No related merge requests found
Pipeline #11505 failed
......@@ -30,14 +30,8 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
feature_importances[classifier_result.view_name] = pd.DataFrame(
index=feature_ids[classifier_result.view_index])
if hasattr(classifier_result.clf, 'feature_importances_'):
print(classifier_result.classifier_name, classifier_result.view_name)
feature_importances[classifier_result.view_name][
classifier_result.classifier_name] = classifier_result.clf.feature_importances_
print(classifier_result.clf.feature_importances_.shape,
feature_importances[classifier_result.view_name][
classifier_result.classifier_name].shape)
else:
feature_importances[classifier_result.view_name][
classifier_result.classifier_name] = np.zeros(
......@@ -62,7 +56,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
def publish_feature_importances(feature_importances, directory, database_name,
feature_stds=None, metric_scores=None, test=False): # pragma: no cover
importance_dfs = []
importance_dfs = [pd.DataFrame()]
std_dfs = []
if not os.path.exists(os.path.join(directory, "feature_importances")):
os.mkdir(os.path.join(directory, "feature_importances"))
......@@ -90,13 +84,17 @@ def publish_feature_importances(feature_importances, directory, database_name,
# columns=feature_std.columns,
# data=np.zeros((1, len(
# feature_std.columns)))))
if "mv" in feature_importances:
importance_dfs.append(feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :])
if len(importance_dfs)>0:
feature_importances_df = pd.concat(importance_dfs)
feature_importances_df = feature_importances_df/feature_importances_df.sum(axis=0)
feature_std_df = pd.concat(std_dfs)
if len(std_dfs)>0:
feature_std_df = pd.concat(std_dfs)
else:
feature_std_df = pd.DataFrame()
if "mv" in feature_importances:
feature_importances_df = pd.concat([feature_importances_df,feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :]], axis=1).fillna(0)
# feature_importances_df = pd.concat([feature_importances_df,feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :]], axis=1).fillna(0)
if feature_stds is not None:
feature_std_df = pd.concat([feature_std_df, feature_stds["mv"]], axis=1,).fillna(0)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment