Skip to content
Snippets Groups Projects
Commit df459c7b authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added plotly for metrics

parent 22ce8b65
No related branches found
No related tags found
No related merge requests found
Pipeline #3867 passed
......@@ -77,7 +77,7 @@ def plot_results_noise(directory, noise_results, metric_to_plot, name, width=0.1
def plot_metric_scores(train_scores, test_scores, names, nb_results, metric_name,
file_name,
tag="", train_STDs=None, test_STDs=None):
tag="", train_STDs=None, test_STDs=None, use_plotly=True):
r"""Used to plot and save the score barplot for a specific metric.
Parameters
......@@ -145,10 +145,28 @@ def plot_metric_scores(train_scores, test_scores, names, nb_results, metric_name
test_STDs.reshape((train_scores.shape[0], 1))), axis=1)),
columns=names, index=["Train", "Train STD", "Test", "Test STD"])
dataframe.to_csv(file_name + ".csv")
if use_plotly:
fig = plotly.graph_objs.Figure()
fig.add_trace(plotly.graph_objs.Bar(
name='Train',
x=names, y=train_scores,
error_y=dict(type='data', array=train_STDs),
marker_color="lightgrey",
))
fig.add_trace(plotly.graph_objs.Bar(
name='Test',
x=names, y=test_scores,
error_y=dict(type='data', array=test_STDs),
marker_color="black",
))
fig.update_layout(title=metric_name + "\n" + tag + " scores for each classifier")
plotly.offline.plot(fig, filename=file_name + ".html", auto_open=False)
del fig
def plot_2d(data, classifiers_names, nbClassifiers, nbExamples,
fileName, minSize=10, labels=None,
file_name, minSize=10, labels=None,
width_denominator=2.0, height_denominator=20.0, stats_iter=1,
use_plotly=True, example_ids=None):
r"""Used to generate a 2D plot of the errors.
......@@ -166,7 +184,7 @@ def plot_2d(data, classifiers_names, nbClassifiers, nbExamples,
The number of examples.
nbCopies : int
The number of times the data is copied (classifier wise) in order for the figure to be more readable
fileName : str
file_name : str
The name of the file in which the figure will be saved ("error_analysis_2D.png" will be added at the end)
minSize : int, optinal, default: 10
The minimum width and height of the figure.
......@@ -191,7 +209,7 @@ def plot_2d(data, classifiers_names, nbClassifiers, nbExamples,
cbar = fig.colorbar(cax, ticks=[-100 * stats_iter / 2, 0, stats_iter])
cbar.ax.set_yticklabels(['Unseen', 'Always Wrong', 'Always Right'])
fig.savefig(fileName + "error_analysis_2D.png", bbox_inches="tight", transparent=True)
fig.savefig(file_name + "error_analysis_2D.png", bbox_inches="tight", transparent=True)
plt.close()
### The following part is used to generate an interactive graph.
if use_plotly:
......@@ -215,7 +233,7 @@ def plot_2d(data, classifiers_names, nbClassifiers, nbExamples,
fig.update_xaxes(showticklabels=False, row=row_index+1, col=1)
fig.update_xaxes(showticklabels=True, row=len(label_index_list), col=1)
plotly.offline.plot(fig, filename=fileName + "error_analysis_2D.html", auto_open=False)
plotly.offline.plot(fig, filename=file_name + "error_analysis_2D.html", auto_open=False)
del fig
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment