Skip to content
Snippets Groups Projects
Select Git revision
  • 6fb992d6eb4eb2f9777fb4ce628872ed7b17b5a1
  • main default protected
  • development
3 results

Add_markers_script.py

Blame
  • user avatar
    raul-silva authored
    Simply removing a level of directory.
    6fb992d6
    History
    Add_markers_script.py 6.86 KiB
    ###################################################################################################
    # Add Markers Script
    # Script version of the Add_markers notebook, meant to be run from the command line
    # Author: Raul Silva
    # Marseille 28-11-2024
    ###################################################################################################
    import argparse
    from lxml import etree
    from pathlib import Path
    import re
    import sqlite3
    import time
    import zlib
    import cv2
    import numpy as np
    import pandas as pd
    from skimage.morphology import binary_erosion
    
    
    # Variables
    # Number of database frames spanned by one video file; used as the upper
    # bound of the DETECTION query window in addMarkers.
    VIDEO_LENGTH = 10 * 60 * 30 # Ten minutes at 30 frames/s (as registered on database)
    # Amount added to the database frame counter for every frame read from the video.
    VIDEO_SUBSAMPLING = 2
    # One color per overlay index: addMarkers uses indices 0-4 for animal ids and
    # 5 and above for objects.
    # NOTE(review): presumably BGR order, since the frames come from OpenCV -- confirm.
    COLOR = [(0, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (0, 255, 255),
             (255, 0, 0), (255, 255, 0), (0, 128, 255), (128, 255, 0)]
    # Number of frames before the video start that are also fetched from the
    # database; half of this value caps the length of each animal's trail.
    START_HISTORY_FACTOR = 60
    # When True, addMarkers also draws a circle for every row of the OBJECTS table.
    add_objects_markers = False
    
    # Functions
    def findInitialFrame(string):
        """Extract the starting frame number encoded in a video file name.

        The frame number is the run of digits found between a ``t`` and the
        following ``.`` (for example ``"cam_t1800.mp4"`` gives ``1800``).

        Args:
            string: Text (typically a file name) containing a ``t<digits>.`` token.

        Returns:
            The frame number as an int.

        Raises:
            ValueError: If ``string`` contains no ``t<digits>.`` token.
        """
        # Raw string: "\." in a normal literal is an invalid escape sequence.
        # "+" instead of "*": a bare "t." (e.g. in "output.t300.mp4") must not
        # match, otherwise the digit run is empty and int("") fails.
        match = re.search(r"t([0-9]+)\.", string)
        if match is None:
            raise ValueError("No 't<frame>.' token found in {0!r}".format(string))
        return int(match.group(1))
    
    
    def generateColoredFrame(dims):
        """Build one solid-color frame per entry of the module-level COLOR list.

        Args:
            dims: Shape of each frame; it is multiplied by a (1, 1, 3) color
                array, so it must broadcast against three channels.

        Returns:
            A list with one array per color, in the same order as COLOR, each
            filled with that color's channel values.
        """
        # np.array([[color]]) has shape (1, 1, 3) and broadcasts over the frame.
        return [np.ones(dims) * np.array([[color]]) for color in COLOR]
    
    
    def unzip(maskDataZipped):
        """Decode a colon-separated hex string and zlib-decompress it.

        Each ``:``-separated item is one byte written in hexadecimal; single
        digit items are missing their leading zero and are padded back to two
        digits before decoding.

        Args:
            maskDataZipped: The encoded string, or None.

        Returns:
            The decompressed bytes, or None when ``maskDataZipped`` is None.

        Raises:
            ValueError: If an item is not valid hexadecimal.
            zlib.error: If the decoded bytes are not valid zlib data.
        """
        if maskDataZipped is None:
            return None
        # Restore the leading zero of one-digit values; fromhex() skips the
        # spaces used as separators.
        hex_pairs = " ".join(
            "0" + value if len(value) == 1 else value
            for value in maskDataZipped.split(":")
        )
        return zlib.decompress(bytearray.fromhex(hex_pairs))
    
    
    def getMetadata(line):
        """Read the animal id and ROI description from one detection row.

        Args:
            line: Mapping-like row providing ``"ANIMALID"`` and ``"DATA"``;
                ``"DATA"`` is an XML document with the elements
                ``/root/ROI/boundsX``, ``boundsY``, ``boundsW``, ``boundsH``
                and ``boolMaskData``.

        Returns:
            A tuple ``(animal_id, [x, y, w, h, mask])`` where ``mask`` is the
            decompressed mask as a list of ints. ``animal_id`` falls back to 0
            when the row has no usable ``"ANIMALID"``.

        Raises:
            ValueError: If an ROI element or the mask data is missing, or a
                bound is not an integer.
        """
        # Only conversion/lookup failures mean "no id"; anything else (e.g.
        # KeyboardInterrupt) must not be swallowed.
        try:
            animal_id = int(line["ANIMALID"])
        except (LookupError, TypeError, ValueError, OverflowError):
            animal_id = 0

        tree = etree.fromstring(line["DATA"])

        fields = {}
        for tag in ("boundsX", "boundsY", "boundsW", "boundsH", "boolMaskData"):
            nodes = tree.xpath("/root/ROI/" + tag)
            if not nodes:
                raise ValueError(
                    "ROI element '{0}' missing from detection data".format(tag))
            # When an element is repeated the last occurrence wins.
            fields[tag] = nodes[-1].text

        x = int(fields["boundsX"])
        y = int(fields["boundsY"])
        w = int(fields["boundsW"])
        h = int(fields["boundsH"])

        raw_mask = unzip(fields["boolMaskData"])
        if raw_mask is None:
            raise ValueError("ROI element 'boolMaskData' is empty")
        # Iterating over bytes yields one int per byte.
        mask = list(raw_mask)
        return animal_id, [x, y, w, h, mask]
    
    
    def generateAnimalContours(frame_dims, x, y, w, h, animal_mask):
        """Render the outline of an animal mask into a frame-sized float mask.

        Args:
            frame_dims: Shape of the output array; channels 0, 1 and 2 are written.
            x, y: Column and row of the ROI's top-left corner in the frame.
            w, h: Width and height of the ROI.
            animal_mask: Flat sequence of ``h * w`` mask values.

        Returns:
            A float32 array of shape ``frame_dims`` that is zero everywhere
            except inside the ROI, where all three channels hold the mask
            minus its 5x5 erosion.
        """
        # Zero-pad by one pixel before eroding, then crop the padding away again.
        padded = np.pad(np.array(animal_mask, dtype = np.uint8).reshape((h, w)), 1)
        eroded = binary_erosion(padded, np.ones((5,5)))
        outline = (padded - eroded)[1:-1, 1:-1]

        mask = np.zeros(frame_dims, dtype = np.float32)
        for channel in range(3):
            mask[y:y + h, x:x + w, channel] = outline
        return mask
    
    
    def getHistoryOld(df, id, t, mask, history_factor = 300, w=2):
        """Stamp a square on ``mask`` at every recent position of one animal.

        Args:
            df: DataFrame with ``FRAMENUMBER``, ``ANIMALID``, ``MASS_X`` and
                ``MASS_Y`` columns.
            id: Animal id whose positions are drawn.
            t: Last frame number (inclusive) of the history window.
            mask: Array indexed as ``[row, column]``; modified in place.
            history_factor: Window length in frames; the window starts at
                ``max(0, t - history_factor)``.
            w: Half-size of the stamped square.

        Returns:
            The same ``mask`` object, with the visited positions set to 1.
        """
        t_start = max(0, t - history_factor)
        selected = df[(df["FRAMENUMBER"] >= t_start)
                      & (df["FRAMENUMBER"] <= t)
                      & (df["ANIMALID"] == id)]
        for _, line in selected.iterrows():
            x = int(line["MASS_X"])
            y = int(line["MASS_Y"])
            # Clamp at 0: a negative slice bound would wrap around to the far
            # edge of the mask instead of being cut off at the border.
            mask[max(0, y - w):max(0, y + w), max(0, x - w):max(0, x + w)] = 1
        return mask
    
    def getHistory(df, id, t, history_factor = 300):
        """Collect the recent positions of one animal.

        Args:
            df: DataFrame with ``FRAMENUMBER``, ``ANIMALID``, ``MASS_X`` and
                ``MASS_Y`` columns.
            id: Animal id to select.
            t: Last frame number (inclusive) of the history window.
            history_factor: Window length in frames; the window starts at
                ``max(0, t - history_factor)``.

        Returns:
            A list of ``[x, y]`` integer pairs, in the row order of ``df``.
        """
        first_frame = max(0, t - history_factor)
        frames = df["FRAMENUMBER"]
        in_window = (frames >= first_frame) & (frames <= t)
        is_animal = df["ANIMALID"] == id
        positions = df.loc[in_window & is_animal, ["MASS_X", 'MASS_Y']]
        return positions.astype('int').values.tolist()
    
    def addHistory(history, mask, w=2):
        """Draw an animal's trajectory as an open polyline on ``mask``.

        Args:
            history: Sequence of ``[x, y]`` positions, oldest first.
            mask: Image array to draw on.
            w: Line thickness in pixels.

        Returns:
            The mask with the polyline drawn in value ``(1, 1, 1)``; ``mask``
            itself, untouched, when ``history`` is empty.
        """
        # Nothing to draw: an empty list would otherwise become a float64
        # array that cv2.polylines does not accept.
        if len(history) == 0:
            return mask
        # cv2.polylines expects 32-bit integer points; make that explicit
        # rather than relying on the platform's default integer type.
        pts = np.array(history, dtype=np.int32).reshape(-1, 1, 2)
        return cv2.polylines(mask, [pts], False, (1,1,1), thickness=w)
    
    def addObject(frame_dims, obj):
        """Create a mask containing the circular outline of one object.

        Args:
            frame_dims: Shape of the mask to create.
            obj: Mapping-like record providing ``"X_CENTER"``, ``"Y_CENTER"``
                and ``"RADIUS"``; each value is rounded to the nearest integer.

        Returns:
            A float32 array of shape ``frame_dims``, zero except for a
            1-pixel-thick circle drawn with value ``(1, 1, 1)``.
        """
        center = (round(obj["X_CENTER"]), round(obj["Y_CENTER"]))
        radius = round(obj["RADIUS"])
        canvas = np.zeros(frame_dims, dtype = np.float32)
        return cv2.circle(canvas, center, radius, (1, 1, 1), 1)
    
    def addMarkers(video_path, output_video_name):
        """Write a copy of a video with animal contours and trails drawn on it.

        The detections are read from the sqlite file named by the module-level
        ``database`` variable (set by the script entry point). When the
        module-level ``add_objects_markers`` flag is true, a circle is also
        drawn for every row of the OBJECTS table.

        Args:
            video_path: ``pathlib.Path`` of the input video; its name must
                contain the starting frame number (see findInitialFrame).
            output_video_name: Path of the video file to write, as a string.

        Returns:
            None. Prints an error and returns without writing anything when
            the input video cannot be opened.
        """
        t1 = time.time()

        # Extract time
        startframe = findInitialFrame(video_path.name)

        # Load the detections; the connection is closed as soon as the data is
        # in memory, even if a query fails.
        query = ("SELECT * FROM DETECTION WHERE FRAMENUMBER>=? AND FRAMENUMBER<=? "
                 "ORDER BY FRAMENUMBER")
        connection = sqlite3.connect(database)
        try:
            dataframe = pd.read_sql(query, connection,
                                    params=(startframe - START_HISTORY_FACTOR,
                                            startframe + VIDEO_LENGTH))
            if add_objects_markers:
                objs_df = pd.read_sql("SELECT * FROM OBJECTS", connection)
        finally:
            connection.close()

        # Open video (OpenCV expects a plain string, not a Path object)
        input_capture = cv2.VideoCapture(str(video_path))
        if not input_capture.isOpened():
            print("Error opening video stream or file")
            input_capture.release()
            return

        height = int(input_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(input_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_size = (width, height)
        output_capture = cv2.VideoWriter(output_video_name, cv2.VideoWriter_fourcc(*'mp4v'), 15, frame_size)

        try:
            # Drawing on the masks
            t = startframe if startframe != 3 else 9 # First video frames are slightly late
            # One trail per animal id 0-4, pre-filled with the positions that
            # precede the first frame of the video.
            animal_history = [getHistory(dataframe, animal, t) for animal in range(5)]
            dims = (height, width, 3)
            colored_frame = generateColoredFrame(dims)

            while input_capture.isOpened():
                ret, frame = input_capture.read()
                if not ret:
                    break

                detections = dataframe[dataframe["FRAMENUMBER"] == t]
                for _, line in detections.iterrows():
                    # Rows whose metadata cannot be decoded are skipped.
                    try:
                        animal_id, [x, y, w, h, animal_mask] = getMetadata(line)
                    except Exception:
                        continue
                    mask = generateAnimalContours(dims, x, y, w, h, animal_mask)
                    trail = animal_history[animal_id]
                    trail.append([int(line["MASS_X"]), int(line["MASS_Y"])])
                    # Drop the oldest point once the trail reaches its maximum length.
                    if len(trail) >= START_HISTORY_FACTOR//2:
                        trail.pop(0)
                    mask = addHistory(trail, mask)
                    # Where mask is 1 the pixel takes the animal's color,
                    # elsewhere the original frame is kept.
                    frame = ((1 - mask) * frame + mask * colored_frame[animal_id]).astype(np.uint8)

                if add_objects_markers:
                    for i, line in objs_df.iterrows():
                        # Objects use the colors that follow the five animal colors.
                        color_index = i + 5
                        mask = addObject(dims, line)
                        frame = ((1 - mask) * frame + mask * colored_frame[color_index]).astype(np.uint8)

                output_capture.write(frame)
                t += VIDEO_SUBSAMPLING
        finally:
            # Always release both handles so the output file gets finalized.
            input_capture.release()
            output_capture.release()

        print("Ended in :", time.time()-t1)
    
    
    
    if __name__ == "__main__":
        # Command-line entry point: mark the given video using the sqlite
        # database located next to it and write the result to a
        # "marked_videos" sub-directory.

        argparser = argparse.ArgumentParser()
        argparser.add_argument("-vid" , "--video_path", required=True,
                               help="Path to video to be converted. Only one at a time. Use batch files \
                                to run over multiple videos")

        args = argparser.parse_args()
        video_path = Path(args.video_path)
        directory = video_path.parent

        # `database` is read by addMarkers as a module-level global: the first
        # file in the video's directory whose name contains "sqlite".
        database = next((str(u.absolute()) for u in directory.iterdir() if 'sqlite' in u.name), None)
        if database is None:
            raise FileNotFoundError("No sqlite database found in {0}".format(directory))

        output_directory = directory / "marked_videos"
        output_directory.mkdir(exist_ok=True)
        output_video_name = str(output_directory / video_path.name)

        addMarkers(video_path, output_video_name)