Something went wrong on our end
Select Git revision
ProgramParameters.hpp
-
Franck Dary authored
plot_PR_curve.py 2.32 KiB
"""Collect the results and plot the graphics"""
import os
import argparse
import numpy as np
import pandas as pd
import glob
import matplotlib as mpl
import matplotlib.pyplot as plt
import utils
def main(arguments):
    """
    Plot the Precision / Recall curve of every class and compute each curve's area.

    The per-class area under the curve is shown in the legend (labelled 'mAP').
    The figure is either displayed interactively or saved as
    ``PR_curve_all_class.jpg`` in ``arguments.path``.

    :param arguments (argparse.Namespace) : parsed command line with
        ``path`` -- directory holding ``class_precision.npy`` and
        ``class_recall.npy`` (the plot is saved there too), and
        ``interactive`` -- truthy to show the plot instead of saving it
    :raises FileNotFoundError: if one of the ``.npy`` files is missing
    """
    # Load the files directly: globbing a literal name only added failure modes
    # (metacharacters in the directory name, IndexError when the file is absent).
    precision = np.load(os.path.join(arguments.path, 'class_precision.npy'))
    recall = np.load(os.path.join(arguments.path, 'class_recall.npy'))
    # precision is indexed as (point, class) below, so its second axis counts classes.
    n_class = precision.shape[1]

    # np.trapz was renamed np.trapezoid in NumPy 2.0 (the old name is deprecated).
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    # One colour per class, sampled evenly from the 'tab20' colormap.
    cmap = mpl.colormaps['tab20']
    colors = cmap(np.linspace(0, 1, n_class))
    # With more than 20 classes the colours repeat, so alternate dashed / solid
    # lines to distinguish neighbouring classes. Building the list per index
    # guarantees exactly n_class entries (repeating the pair round(n_class / 2)
    # times falls one short for odd counts such as 21 or 25).
    if n_class > 20:
        lines = ['dashed' if index % 2 == 0 else 'solid' for index in range(n_class)]
    else:
        lines = ['solid'] * n_class

    fig = plt.figure(figsize=(15, 8))
    # Compute the area under the curve for each class and plot it in the figure.
    for class_index in range(n_class):
        class_precision = precision[:, class_index]
        # NOTE(review): recall is normally a single x axis shared by all classes;
        # if it is stored per class (2-D like precision), use this class's column.
        class_recall = recall[:, class_index] if recall.ndim == 2 else recall
        # Trapezoid integration is signed: recall stored in decreasing order would
        # otherwise give a negative area.
        area_under_curve = np.round(abs(trapezoid(class_precision, class_recall)), 2)
        plt.plot(class_recall, class_precision,
                 label=f'mAP {class_index} = {area_under_curve}',
                 color=colors[class_index],
                 linestyle=lines[class_index])
    # The legend is anchored below the axes.
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
               fancybox=True, ncol=8)
    plt.tight_layout()
    if arguments.interactive:
        plt.show()
    else:
        # bbox_inches='tight' grows the saved image so the outside legend is not clipped.
        fig.savefig(os.path.join(arguments.path, 'PR_curve_all_class.jpg'),
                    bbox_inches='tight')
        print(f'Saved in {arguments.path} as PR_curve_all_class.jpg')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
if __name__ == "__main__":
    # Command-line entry point: parse the results directory and the display mode,
    # then delegate to main().
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Collect the results and plot the graphics')
    # The directory argument is converted/validated by the project helper utils.arg_directory.
    parser.add_argument('path', type=utils.arg_directory,
                        help='Path of the folder/file that contains the val / test results')
    # store_true is the idiomatic on/off flag; main() only checks its truthiness.
    parser.add_argument('--interactive', action='store_true',
                        help='If in arguments, will show the plot in interactive mode, else will save the plot in path')
    args = parser.parse_args()
    main(args)