Skip to content
Snippets Groups Projects
Commit 0c57cb4e authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Removed subplots and tried fbeta

parent cf34b931
No related branches found
No related tags found
No related merge requests found
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="module" module-name="multiviewmetriclearning" />
<orderEntry type="module" module-name="multiview_generator" />
<orderEntry type="module" module-name="short_projects" />
<orderEntry type="library" name="R User Library" level="project" />
<orderEntry type="library" name="R Skeletons" level="application" />
<orderEntry type="module" module-name="Datasets" />
</component>
</module>
\ No newline at end of file
......@@ -18,7 +18,7 @@ def score(y_true, y_pred, multiclass=False, **kwargs):
try:
beta = kwargs["1"]
except Exception:
beta = 1.0
beta = 10.0
try:
labels = kwargs["2"]
except Exception:
......
......@@ -217,27 +217,27 @@ def plot_2d(data, classifiers_names, nbClassifiers, nbExamples,
plt.close()
### The following part is used to generate an interactive graph.
if use_plotly:
label_index_list = [np.arange(len(labels))] #[np.where(labels==i)[0] for i in np.unique(labels)]
hover_text = [[example_ids[i] + " failed "+ str(stats_iter-data[i,j])+" time(s)"
for j in range(data.shape[1])]
for i in range(data.shape[0]) ]
fig = plotly.subplots.make_subplots(rows=len(label_index_list), cols=1)
for row_index, label_index in enumerate(label_index_list):
label_index_list = np.concatenate([np.where(labels==i)[0] for i in np.unique(labels)]) #[np.where(labels==i)[0] for i in np.unique(labels)]
hover_text = [[example_ids[example_index] + " failed "+ str(stats_iter-data[example_index,classifier_index])+" time(s), labelled "+str(example_index)
for classifier_index in range(data.shape[1])]
for example_index in range(data.shape[0]) ]
fig = plotly.graph_objs.Figure()
# for row_index, label_index in enumerate(label_index_list):
fig.add_trace(plotly.graph_objs.Heatmap(
x=list(classifiers_names),
y=[example_ids[label_ind] for label_ind in label_index],
z=data[label_index, :],
text=hover_text,
y=[example_ids[label_ind] for label_ind in label_index_list],
z=data[label_index_list, :],
text=[hover_text[label_ind] for label_ind in label_index_list],
hoverinfo=["y", "x", "text"],
colorscale="Greys",
colorbar=dict(tickvals=[0, stats_iter],
ticktext=["Always Wrong", "Always Right"]),
reversescale=True), row=row_index+1, col=1)
fig.update_yaxes(title_text="Label "+str(row_index), showticklabels=False, ticks='', row=row_index+1, col=1)
fig.update_xaxes(showticklabels=False, row=row_index+1, col=1)
reversescale=True),)
fig.update_yaxes(title_text="Examples", showticklabels=False, ticks='')
fig.update_xaxes(showticklabels=False,)
fig.update_layout(paper_bgcolor = 'rgba(0,0,0,0)',
plot_bgcolor = 'rgba(0,0,0,0)')
fig.update_xaxes(showticklabels=True, row=len(label_index_list), col=1)
fig.update_xaxes(showticklabels=True, )
plotly.offline.plot(fig, filename=file_name + "error_analysis_2D.html", auto_open=False)
del fig
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment