Skip to content
Snippets Groups Projects
Commit ecb1b89c authored by Kossi Kossivi's avatar Kossi Kossivi
Browse files

Fix feature_importances issue due to mismatch between classifier name and...

Fix feature_importances issue due to mismatch between classifier name and registered name in metrics_scores
parent a99d6031
No related branches found
No related tags found
No related merge requests found
......@@ -44,19 +44,22 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
v_feature_id]
feature_importances["mv"] = pd.DataFrame(index=feat_ids)
if hasattr(classifier_result.clf, 'feature_importances_'):
feature_importances["mv"][classifier_result.classifier_name] = classifier_result.clf.feature_importances_
feature_importances["mv"][classifier_result.get_classifier_name()] = classifier_result.clf.feature_importances_
else:
# HACK: Assigning a default features importances values to classifier that hasn't feature_importances_
# attribute (eg: Linear Late Fusion)
feature_importances["mv"][classifier_result.get_classifier_name()] = np.zeros(len(feature_importances["mv"].index))
return feature_importances
def publish_feature_importances(feature_importances, directory, database_name,
feature_stds=None, metric_scores=None): # pragma: no cover
# TODO: Manage the case with NAN values
importance_dfs = []
std_dfs = []
if not os.path.exists(os.path.join(directory, "feature_importances")):
os.mkdir(os.path.join(directory, "feature_importances"))
for view_name, feature_importance in feature_importances.items():
if view_name!="mv":
if feature_stds is not None:
feature_std = feature_stds[view_name]
else:
......@@ -65,33 +68,22 @@ def publish_feature_importances(feature_importances, directory, database_name,
columns=feature_importance.columns)
feature_std = feature_std.loc[feature_importance.index]
if view_name == "mv":
importance_dfs.append(feature_importance)
std_dfs.append(feature_std)
else:
importance_dfs.append(feature_importance.set_index(
pd.Index([view_name + "-" + ind for ind in list(feature_importance.index)])))
importance_dfs.append(feature_importance.set_index(pd.Index([view_name+"-"+ind for ind in list(feature_importance.index)])))
# importance_dfs.append(pd.DataFrame(index=[view_name+"-br"],
# columns=feature_importance.columns,
# data=np.zeros((1, len(
# feature_importance.columns)))))
std_dfs.append(feature_std.set_index(pd.Index([view_name + "-" + ind
for ind
in list(feature_std.index)])))
# std_dfs.append(pd.DataFrame(index=[view_name + "-br"],
# columns=feature_std.columns,
# data=np.zeros((1, len(
# feature_std.columns)))))
if len(importance_dfs) > 0:
feature_importances_df = pd.concat(importance_dfs)
feature_importances_df = feature_importances_df / feature_importances_df.sum(axis=0)
feature_std_df = pd.concat(std_dfs)
if "mv" in feature_importances:
feature_importances_df = pd.concat([feature_importances_df,feature_importances["mv"].loc[(feature_importances["mv"] != 0).any(axis=1), :]], axis=1).fillna(0)
if feature_stds is not None:
feature_std_df = pd.concat([feature_std_df, feature_stds["mv"]], axis=1,).fillna(0)
else:
fake = pd.DataFrame(data=np.zeros((feature_importances_df.shape[0], feature_importances["mv"].shape[1])),
index=feature_importances_df.index,
columns=feature_importances["mv"].columns).fillna(0)
feature_std_df = pd.concat([feature_std_df, fake], axis=1,).fillna(0)
plot_feature_importances(os.path.join(directory, "feature_importances",
database_name), feature_importances_df, feature_std_df)
if metric_scores is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment