Skip to content
Snippets Groups Projects
Commit f9fb75bf authored by Stephane Chavin's avatar Stephane Chavin
Browse files

add the PR curve plot

parent 76e3aea3
Branches
No related tags found
No related merge requests found
......@@ -275,6 +275,19 @@ python yolov5/train.py --imgsz <IMG_SIZE> --batch <BATCH_SIZE> --epochs <NB_EPOC
```
<br />
# Validating the model
```bash
python yolov5/val.py --task test --imgsz <IMG_SIZE> --data <custom_data.yaml> --weights yolov5/train/<your_train>/weights/best.pt
```
To vizualise the Precision / Recall data, use the following command :
```bash
python plot_PR_curve.py yolov5/runs/val/<your_val> --interactive
```
# Detection
<br />
......
"""Collect the results and plot the graphics"""
import os
import argparse
import numpy as np
import pandas as pd
import glob
import matplotlib as mpl
import matplotlib.pyplot as plt
import utils
def main(arguments):
"""
Plot and save the Precision / Recall curve and compute area under the curve (mAP).
:param argument (argparse) : Parser containing Path and Directory
"""
precision = np.load(glob.glob(os.path.join(arguments.path, 'class_precision.npy'))[0])
recall = np.load(glob.glob(os.path.join(arguments.path, 'class_recall.npy'))[0])
n_class = precision.shape[1]
# Load a colormap
cmap = mpl.colormaps['tab20']
# If more class than color in the colormap, switch with dashed and solid lines
if n_class > 20:
lines = ['dashed','solid'] * round(n_class / 2)
colors = cmap(np.linspace(0, 1, n_class))
else:
lines = ['solid'] * n_class
colors = cmap(np.linspace(0, 1, n_class))
fig = plt.figure(figsize=([15,8]))
# Compute the area under the curve for each class and plot in the figure
for class_index in range(0, n_class):
area_under_curve = np.round(np.trapz(precision[:,class_index], recall), 2)
plt.plot(recall,precision[:,class_index],
label=f'mAP {class_index} = {area_under_curve}',
color=colors[class_index],
linestyle=lines[class_index])
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
fancybox=True, ncol=8)
plt.tight_layout()
if arguments.interactive:
plt.show()
else:
plt.savefig(os.path.join(arguments.path, 'PR_curve_all_class.jpg'))
print(f'Saved in {arguments.path} as PR_curve_all_class.jpg')
return
if __name__ == "__main__":
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, description='Collect the results and plot the graphics')
parser.add_argument('path', type=utils.arg_directory,
help='Path of the folder/file that contains the val / test results')
parser.add_argument('--interactive', action='store_const',
help='If in arguments, will show the plot in interactive mode, else will save the plot in path',
const=1, default=None)
args = parser.parse_args()
main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment