diff --git a/README.md b/README.md index 23dd2546c01826382382d842fb9e84de61890736..2463f0a1c74c338ebd0947c27bfef5adc7d8f75a 100755 --- a/README.md +++ b/README.md @@ -275,6 +275,19 @@ python yolov5/train.py --imgsz <IMG_SIZE> --batch <BATCH_SIZE> --epochs <NB_EPOC ``` <br /> +# Validating the model + +```bash +python yolov5/val.py --task test --imgsz <IMG_SIZE> --data <custom_data.yaml> --weights yolov5/train/<your_train>/weights/best.pt +``` + +To vizualise the Precision / Recall data, use the following command : + +```bash +python plot_PR_curve.py yolov5/runs/val/<your_val> --interactive +``` + + # Detection <br /> diff --git a/plot_PR_curve.py b/plot_PR_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..5d19b82b8bf9d50835ff3c09023324ec8c565a07 --- /dev/null +++ b/plot_PR_curve.py @@ -0,0 +1,64 @@ +"""Collect the results and plot the graphics""" + +import os +import argparse +import numpy as np +import pandas as pd +import glob +import matplotlib as mpl +import matplotlib.pyplot as plt +import utils + +def main(arguments): + """ + Plot and save the Precision / Recall curve and compute area under the curve (mAP). + :param argument (argparse) : Parser containing Path and Directory + """ + + precision = np.load(glob.glob(os.path.join(arguments.path, 'class_precision.npy'))[0]) + recall = np.load(glob.glob(os.path.join(arguments.path, 'class_recall.npy'))[0]) + + n_class = precision.shape[1] + # Load a colormap + cmap = mpl.colormaps['tab20'] + + # If more class than color in the colormap, switch with dashed and solid lines + if n_class > 20: + lines = ['dashed','solid'] * round(n_class / 2) + colors = cmap(np.linspace(0, 1, n_class)) + else: + lines = ['solid'] * n_class + colors = cmap(np.linspace(0, 1, n_class)) + + fig = plt.figure(figsize=([15,8])) + + # Compute the area under the curve for each class and plot in the figure + for class_index in range(0, n_class): + area_under_curve = np.round(np.trapz(precision[:,class_index], recall), 2) + + plt.plot(recall,precision[:,class_index], + label=f'mAP {class_index} = {area_under_curve}', + color=colors[class_index], + linestyle=lines[class_index]) + plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), + fancybox=True, ncol=8) + plt.tight_layout() + + if arguments.interactive: + plt.show() + else: + plt.savefig(os.path.join(arguments.path, 'PR_curve_all_class.jpg')) + print(f'Saved in {arguments.path} as PR_curve_all_class.jpg') + return + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, description='Collect the results and plot the graphics') + parser.add_argument('path', type=utils.arg_directory, + help='Path of the folder/file that contains the val / test results') + parser.add_argument('--interactive', action='store_const', + help='If in arguments, will show the plot in interactive mode, else will save the plot in path', + const=1, default=None) + args = parser.parse_args() + + main(args)