Skip to content
Snippets Groups Projects
Select Git revision
  • 43f370e2e8ccff1a955a01705c17dfd7eacb67a6
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.1
11 results

test_configuration.py

Blame
  • plot_PR_curve.py 2.32 KiB
    """Collect the results and plot the graphics""" 
    
    import os 
    import argparse 
    import numpy as np
    import pandas as pd
    import glob
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import utils
    
    def _load_result(directory, filename):
        """
        Load the first ``.npy`` file matching ``filename`` inside ``directory``.
        :param directory (str) : Folder containing the val / test results
        :param filename (str) : File name (glob pattern accepted) to look for
        :return (np.ndarray) : The loaded array
        :raises FileNotFoundError: If no file matches (instead of an opaque IndexError)
        """
        matches = glob.glob(os.path.join(directory, filename))
        if not matches:
            raise FileNotFoundError(f'No file matching {filename} in {directory}')
        return np.load(matches[0])


    def main(arguments):
        """
        Plot the Precision / Recall curve of every class and the area under each curve.

        The per-class area is rounded to 2 decimals and shown in the legend (labelled ``mAP``).
        The figure is shown interactively or saved as ``PR_curve_all_class.jpg`` in ``path``.

        :param arguments (argparse.Namespace) : Parsed arguments; ``path`` is the folder holding
            ``class_precision.npy`` / ``class_recall.npy``, ``interactive`` selects showing the
            plot instead of saving it
        :raises FileNotFoundError: If one of the two ``.npy`` files is missing
        """
        precision = _load_result(arguments.path, 'class_precision.npy')
        recall = _load_result(arguments.path, 'class_recall.npy')

        # Each column of ``precision`` is one class; the same ``recall`` is the x-axis for all.
        n_class = precision.shape[1]
        # Load a colormap
        cmap = mpl.colormaps['tab20']
        colors = cmap(np.linspace(0, 1, n_class))

        # tab20 only has 20 colours: beyond that, neighbouring classes share a colour, so
        # alternate dashed and solid lines to keep them distinguishable.
        # The list is built per index so it always has exactly n_class entries: the former
        # ``['dashed', 'solid'] * round(n_class / 2)`` was one short for n_class = 21, 25, ...
        # because round() rounds half to even (round(10.5) == 10), raising IndexError below.
        if n_class > 20:
            lines = [('dashed', 'solid')[index % 2] for index in range(n_class)]
        else:
            lines = ['solid'] * n_class

        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid (same signature).
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz

        fig = plt.figure(figsize=[15, 8])

        # Compute the area under the curve for each class and plot in the figure
        for class_index in range(n_class):
            area_under_curve = np.round(trapezoid(precision[:, class_index], recall), 2)
            plt.plot(recall, precision[:, class_index],
                     label=f'mAP {class_index} = {area_under_curve}',
                     color=colors[class_index],
                     linestyle=lines[class_index])
        plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
                   fancybox=True, ncol=8)
        plt.tight_layout()

        if arguments.interactive:
            plt.show()
        else:
            plt.savefig(os.path.join(arguments.path, 'PR_curve_all_class.jpg'))
            print(f'Saved in {arguments.path} as PR_curve_all_class.jpg')
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    
    if __name__ == "__main__":
        # Command-line entry point: parse the results folder and the display mode, then plot.
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter, description='Collect the results and plot the graphics')
        parser.add_argument('path', type=utils.arg_directory,
                            help='Path of the folder/file that contains the val / test results')
        # store_true is the idiomatic boolean flag (False when absent, True when given);
        # main() only checks the truthiness of this value.
        parser.add_argument('--interactive', action='store_true',
                            help='If in arguments, will show the plot in interactive mode, else will save the plot in path')
        args = parser.parse_args()

        main(args)