diff --git a/overlay_annot_detect.py b/overlay_annot_detect.py new file mode 100644 index 0000000000000000000000000000000000000000..3612db63520d3e904dde28a75d035ef7a3312ab9 --- /dev/null +++ b/overlay_annot_detect.py @@ -0,0 +1,81 @@ +import numpy as np +import matplotlib.pyplot as plt +import cv2 +import glob +import os +import argparse +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +from matplotlib import cm +import pandas as pd + +# Define constants for colors +colors = cm.get_cmap('Blues', 50) +colors_yolo = cm.get_cmap('Greens', 50) + + +def arg_directory(path): + if os.path.isdir(path): + return path + else: + raise argparse.ArgumentTypeError(f'`{path}` is not a valid path') + +def overlay_annotations(image_path, annotation_path, detection_path, output_directory): + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + H, W = image.shape[0], image.shape[1] + base_name = os.path.splitext(os.path.basename(image_path))[0] + + try: + shape = pd.read_csv(annotation_path, sep=' ', names=['label', 'x', 'y', 'w', 'h']) + shape_yolo = pd.read_csv(detection_path, sep=' ', names=['label', 'x', 'y', 'w', 'h']) + except Exception: + return + + for shape_df, colors_palette in zip([shape, shape_yolo], [colors, colors_yolo]): + for z in range(len(shape_df)): + x, y, w, h = shape_df.iloc[z][['x', 'y', 'w', 'h']] * [W, H, W, H] + + # Calculate rectangle coordinates + shape1 = (int(x - 0.5 * w), int(y + 0.5 * h)) + shape4 = (int(x + 0.5 * w), int(y - 0.5 * h)) + + # Calculate text coordinates + shp1 = (shape1[0], shape1[1] + 20) + shp4 = (shape4[0], shape4[1]) + text_shape = (shp1[0], shp1[1] - 5) + + label = str(shape_df.label.iloc[z]) + + # Draw rectangle and text + cv2.rectangle(image, pt1=shape1, pt2=shape4, color=colors_palette[shape_df.label.iloc[z]], thickness=1) + cv2.rectangle(image, pt1=shp1, pt2=shp4, color=colors_palette[shape_df.label.iloc[z]], thickness=-1) + cv2.putText(image, label, text_shape, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) + + output_path = os.path.join(output_directory, f'{base_name}.jpg') + plt.imshow(image, cmap='jet') + plt.title('Blues : ANNOTATIONS; Greens : YOLO DETECTIONS', loc = 'center') + plt.subplots_adjust(top=1, bottom=0, left=0, right=1) + plt.savefig(output_path) + plt.close() + +if __name__ == '__main__': + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description='TODO') + parser.add_argument('-p', '--path_to_data', type=str, help='Path of the folder that contains the .jpg files (*/set/images/)', required=True) + parser.add_argument('-s', '--detection', type=str, help='Path the folder containing the .txt detection (*/exp/labels)', required=True) + parser.add_argument('-d', '--directory', type=arg_directory, help='Directory to which the overlayed images will be stored', required=True) + parser.add_argument('-a', '--annotation', type=str, help='Path the .txt containing the annotation (*/train/labels/)', required=True) + args = parser.parse_args() + + path = args.path_to_data + directory = args.directory + detection = args.detection + annotation = args.annotation + + image_files = glob.glob(os.path.join(path, '*', '*.jpg')) + + for image_path in image_files: + base_name = os.path.splitext(os.path.basename(image_path))[0] + annotation_path = os.path.join(annotation, f'{base_name}.txt') + detection_path = os.path.join(detection, f'{base_name}.txt') + + overlay_annotations(image_path, annotation_path, detection_path, directory)