Skip to content
Snippets Groups Projects
Select Git revision
  • e93e048918f88063449c03ea19801b07d79fbd84
  • master default protected
2 results

chainInsertionTest.ml

Blame
  • Post-process_events.ipynb 25.76 KiB

    Post-process events

    The goal of this notebook is to gather all the tools created so far to post-process the events stored in the database:

    - Merge close events
    - Filter out events of short duration
    

    And also include some new features:

    - Save post-processed events
    - Plot graphics of activity
    

    Loading Python Packages

    '''
    Created on 7 jul. 2023
    
    @author: Raul
    ! Based on lmt analysis of Fabrice de Chaumont
    '''
    %matplotlib inline
    
    import sqlite3
    from pathlib import Path
    import os
    import shutil
    import yaml
    
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    
    from lmtanalysis.FileUtil import getFilesToProcess
    from lmtanalysis.Measure import *
    from lmtanalysis.Event import EventTimeLine
    from lmtanalysis import BuildEventTrain3, BuildEventTrain4, BuildEventTrain2, BuildEventFollowZone, BuildEventRear5, \
        BuildEventFloorSniffing, BuildEventSocialApproach, BuildEventSocialEscape, BuildEventApproachContact, \
        BuildEventOralOralContact, BuildEventApproachRear, BuildEventGroup2, BuildEventGroup3, BuildEventGroup4, \
        BuildEventOralGenitalContact, BuildEventStop, BuildEventWaterPoint, BuildEventHuddling, \
        BuildEventMove, BuildEventGroup3MakeBreak, BuildEventGroup4MakeBreak, BuildEventSideBySide, \
        BuildEventSideBySideOpposite, BuildEventWallJump, BuildEventSAP, BuildEventOralSideSequence,\
        BuildEventNest4, BuildEventNest3, BuildEventGetAway, BuildEventCenterPeripheryLocation, \
        BuildEventOtherContact, BuildEventPassiveAnogenitalSniff
    from lmtanalysis.split_database import *
    from lmtanalysis.Util import getNumberOfFrames
    # Plot color per animal, looked up as id_animal_color[id1] by the processing loop.
    # The ID loops in this notebook run over range(1, 5), so index 0 is not used by them.
    id_animal_color = [ "black", "red", "green", "purple", "orange"]
    # Name of the sub-directory, created next to each database file, that receives
    # the generated plots and CSV files.
    filtered_events_dir = "filtered_events_data"
    
    def plotGraphics(x_data, y_data, **kwargs):
        """Plot y_data against x_data, save the figure and dump both series to a CSV file.

        Optional keyword arguments (default in parentheses):
            line_type    -- matplotlib format string ('-o')
            color        -- line color ("black")
            x_limit      -- upper bound of the x axis; the lower bound is always 0 (0)
            y_limit      -- read but currently unused, the ylim call is disabled (10)
            x_label      -- x axis label ("Time [minutes]")
            y_label      -- y axis label ("Time [minutes]")
            title        -- figure title ("Time spend on action by mice ID")
            filename     -- path of the saved figure ("time_action_ID.png")
            csv_filename -- path of the saved CSV ("time_spent_per_bin.csv")
        """
        defaults = {
            "line_type": '-o',
            "color": "black",
            "x_limit": 0,
            "y_limit": 10,
            "x_label": "Time [minutes]",
            "y_label": "Time [minutes]",
            "title": "Time spend on action by mice ID",
            "filename": "time_action_ID.png",
            "csv_filename": "time_spent_per_bin.csv",
        }
        # Caller-supplied values override the defaults, exactly like kwargs.get(key, default).
        opts = {**defaults, **kwargs}

        plt.rcParams['figure.figsize'] = [15, 5]
        plt.plot(x_data, y_data, opts["line_type"], color = opts["color"])
        plt.xlim(0, opts["x_limit"])
    #     plt.ylim(0, opts["y_limit"])

        # With more than 10 points, keep only every (len // 7)-th x value as a tick.
        n_points = len(x_data)
        tick_positions = x_data[::n_points // 7] if n_points > 10 else x_data
        plt.xticks(tick_positions)

        plt.xlabel(opts["x_label"])
        plt.ylabel(opts["y_label"])
        plt.title(opts["title"])
        plt.savefig(opts["filename"])
        plt.close()

        # One row per point: first column is x_data, second column is y_data.
        table = pd.DataFrame(data = np.array([x_data, y_data]).T, columns=["Time", "Count"])
        table.to_csv(opts["csv_filename"])
    
    def computeTimeSpent(eventTL, **kwargs):
        """Return the per-bin values of eventTL.getDensityEventInTimeBin, divided by a scale.

        Optional keyword arguments (default in parentheses):
            t_min        -- first frame of the window (0)
            t_max        -- end of the window; 1 is subtracted before it is passed on (oneHour)
            timestep     -- bin size in frames (10 * oneMinute)
            events_scale -- divisor applied to every bin value (oneMinute)

        NOTE(review): with return_density=False the bin values are presumably frame
        counts, so dividing by oneMinute yields minutes -- confirm against EventTimeLine.
        """
        first_frame = kwargs.get("t_min", 0)
        last_frame = kwargs.get("t_max", oneHour) - 1
        bin_size = kwargs.get("timestep", 10 * oneMinute)
        divisor = kwargs.get("events_scale", oneMinute)

        raw_bins = eventTL.getDensityEventInTimeBin(
            tmin=first_frame, tmax=last_frame, binSize=bin_size, return_density=False
        )

        scaled_bins = []
        for value in raw_bins:
            scaled_bins.append(value / divisor)
        return scaled_bins
    
    
    def computeNumberEvents(eventTL, **kwargs):
        """Return eventTL.getNumberOfEvent(start, start + timestep) for each time bin.

        Bin starts are range(t_min, t_max - 1, timestep).

        Optional keyword arguments (default in parentheses):
            t_min    -- first frame of the window (0)
            t_max    -- end of the window; 1 is subtracted before building the bins (oneHour)
            timestep -- bin size in frames (10 * oneMinute)
        """
        first_frame = kwargs.get("t_min", 0)
        last_frame = kwargs.get("t_max", oneHour) - 1
        bin_size = kwargs.get("timestep", 10 * oneMinute)

        return [
            eventTL.getNumberOfEvent(bin_start, bin_start + bin_size)
            for bin_start in range(first_frame, last_frame, bin_size)
        ]
    
    def filterEventTimeLine(connection, **kwargs):
        """Load one event timeline from the database and clean it up.

        Returns None when no "action" keyword is given (or it is None). Otherwise
        returns the EventTimeLine after three steps, in this order:
            1. removeEventsBelowLength(minimum_event_particle_length)
            2. closeEvents(merge_events_length)
            3. removeEventsBelowLength(minimum_event_length)

        Optional keyword arguments (default in parentheses):
            action (None), idA / idB / idC / idD (None), t_min (0),
            t_max (oneHour; 1 is subtracted before use),
            minimum_event_particle_length (0), merge_events_length (5),
            minimum_event_length (30).
        """
        event_name = kwargs.get("action", None)
        if event_name is None:
            return None

        timeline = EventTimeLine(
            connection,
            event_name,
            idA= kwargs.get("idA", None),
            idB= kwargs.get("idB", None),
            idC= kwargs.get("idC", None),
            idD= kwargs.get("idD", None),
            minFrame= kwargs.get("t_min", 0),
            maxFrame= kwargs.get("t_max", oneHour) - 1,
        )

        # The order matters: fragments are dropped before merging, and the final
        # length filter is applied to the merged events.
        timeline.removeEventsBelowLength(kwargs.get("minimum_event_particle_length", 0))
        timeline.closeEvents(kwargs.get('merge_events_length', 5))
        timeline.removeEventsBelowLength(kwargs.get("minimum_event_length", 30))

        return timeline
    
    def saveTimeLine2CSV(eventTL, csv_events_file):
        """Write every event of eventTL to csv_events_file, one row per event.

        Each row holds the timeline's four animal IDs followed by the event's start
        frame, end frame, duration() and metadata. Nothing is written when the
        timeline has no events; a notice is printed instead.
        """
        if len(eventTL.eventList) == 0:
            print("Event Length: 0")
            print("CSV file NOT Saved")
            return

        columns = ["ID_A",  "ID_B", "ID_C", "ID_D", "Start frame", "End frame", "Duration (in frame)", "Metadata"]
        rows = [
            [eventTL.idA, eventTL.idB, eventTL.idC, eventTL.idD,
             ev.startFrame, ev.endFrame, ev.duration(), ev.metadata]
            for ev in eventTL.eventList
        ]

        table = pd.DataFrame(data = np.array(rows), columns=columns)
        table.to_csv(csv_events_file)
        print("Saved to :", csv_events_file)
    # Dispatch table: event name -> function used to load and filter that event's timeline.
    # The processing loop calls the selected function with the database connection plus
    # keyword arguments (t_min, t_max, idA..idD, nest4_TL, action and the per-action
    # parameters read from the YAML file).
    # Events with no dedicated lmtanalysis module use the generic filterEventTimeLine
    # defined in this notebook.
    # Commented-out entries are events for which no function is registered.
    scripts_events_dictionary = {
    "Drinking": filterEventTimeLine,
    "Eating": filterEventTimeLine,
    "Approach": filterEventTimeLine,
    "Approach contact": BuildEventApproachContact.filterEventTimeLine,
    "Approach rear": BuildEventApproachRear.filterEventTimeLine,
    "Break contact": filterEventTimeLine,
    "Contact": filterEventTimeLine,
    # "Detection": None,
    "Floor sniffing": BuildEventFloorSniffing.filterEventTimeLine,
    "FollowZone": BuildEventFollowZone.filterEventTimeLine,
    "Get away": BuildEventGetAway.filterEventTimeLine,
    "Group 3 break": BuildEventGroup3MakeBreak.filterEventTimeLine,
    "Group 3 make": BuildEventGroup3MakeBreak.filterEventTimeLine,
    "Group 4 break": BuildEventGroup4MakeBreak.filterEventTimeLine,
    "Group 4 make": BuildEventGroup4MakeBreak.filterEventTimeLine,
    "Group2": BuildEventGroup2.filterEventTimeLine,
    "Group3": BuildEventGroup3.filterEventTimeLine,
    "Group4": BuildEventGroup4.filterEventTimeLine,
    # "Head detected": None,
    "Huddling": BuildEventHuddling.filterEventTimeLine,
    "Look down": filterEventTimeLine,
    "Look up": filterEventTimeLine,
    # "MACHINE LEARNING ASSOCIATION": None,
    "Move in contact": BuildEventMove.filterEventTimeLine,
    "Move isolated": BuildEventMove.filterEventTimeLine,
    "Nest4": BuildEventNest4.filterEventTimeLine,
    "Nest3": BuildEventNest3.filterEventTimeLine,
    "Oral-genital Contact": BuildEventOralGenitalContact.filterEventTimeLine,
    "Oral-oral Contact": BuildEventOralOralContact.filterEventTimeLine,
    "Other contact": BuildEventOtherContact.filterEventTimeLine,
    "Passive oral-genital Contact": BuildEventPassiveAnogenitalSniff.filterEventTimeLine,
    # "RFID ASSIGN ANONYMOUS TRACK": None,
    # "RFID MATCH": None,
    # "RFID MISMATCH": None,
    "Rear in contact": BuildEventRear5.filterEventTimeLine,
    "Rear isolated": BuildEventRear5.filterEventTimeLine,
    # "Rearing": None,
    "SAP": BuildEventSAP.filterEventTimeLine,
    "Side by side Contact": BuildEventSideBySide.filterEventTimeLine,
    "Side by side Contact, opposite way": BuildEventSideBySideOpposite.filterEventTimeLine,
    "Social approach": BuildEventSocialApproach.filterEventTimeLine,
    "Social escape": BuildEventSocialEscape.filterEventTimeLine,
    #"Stop": None,
    "Stop in contact": BuildEventStop.filterEventTimeLine,
    "Stop isolated": BuildEventStop.filterEventTimeLine,
    "Train2": BuildEventTrain2.filterEventTimeLine,
    "Train3": BuildEventTrain3.filterEventTimeLine,
    "Train4": BuildEventTrain4.filterEventTimeLine,
    "WallJump": BuildEventWallJump.filterEventTimeLine,
    "Water Stop": BuildEventWaterPoint.filterEventTimeLine,
    "Water Zone": BuildEventWaterPoint.filterEventTimeLine,
    "seq oral geni - oral oral": BuildEventOralSideSequence.filterEventTimeLine,
    "seq oral oral - oral genital": BuildEventOralSideSequence.filterEventTimeLine,
    "Periphery Zone": BuildEventCenterPeripheryLocation.filterEventTimeLine,
    "Center Zone": BuildEventCenterPeripheryLocation.filterEventTimeLine,
    }

    Ask the user which file(s) to process, using a dialog window:

    • note 1: You can run this step only once, and then keep processing the file(s) with the next cells.
    • note 2: The file window can be hidden by other windows.
    # Ask the user which database file(s) to process.
    # The result is kept in `files` and reused by every later cell.
    print( "Select file name in window")
    files = getFilesToProcess()

    Parameters

    Change this cell to choose:

    • The time slot to be filtered;
    • The timestep in which events will be processed;
    • What you want to save.

    The actions being post-processed and their parameters can be found in 'actions_pproc_params.yaml'.

    # Event parameters: frame window handed to the filtering functions as t_min / t_max.
    minf = 0
    # NOTE(review): the frame count is taken from the first selected file only and is
    # then applied to every file in `files` -- confirm this is intended when several
    # databases of different lengths are selected.
    maxf = getNumberOfFrames(files[0])
    
    # Plotting parameters: bin size (in frames) used for the per-bin plots and CSV files.
    timestep = 10 * oneMinute
    
    # Saving parameters: choose which outputs the processing loop produces.
    save_graphics = True       # per-bin PNG plots (and their per-bin CSV files)
    save_csv_file = True       # one CSV of filtered events per action / ID combination
    save_sqlite_file = False   # write filtered events into a '<name>_filtered' database copy
    def _ids_label(n_ids, id1, id2, id3):
        """Return the animal-ID suffix used in plot titles and output file names."""
        if n_ids == 2:
            return f"{id1}-{id2}"
        if n_ids == 3:
            return f"{id1}-{id2}-{id3}"
        # n_ids == 1 and every other value (e.g. 0 or 4) are labelled with the first
        # ID only, so those outputs are repeated per first animal.
        return f"{id1}"


    # The per-action parameters do not depend on the database, so read them once.
    # yaml.SafeLoader only builds plain Python types from the file.
    with open('actions_pproc_params.yaml', 'r') as f:
        actions_dictionary = yaml.load(f, Loader=yaml.SafeLoader)

    for file in files:
        # connect to database
        connection = sqlite3.connect(file)
        target_connection = None
        try:
            sqlite_dir = Path(file).parent
            if save_sqlite_file:
                # Build '<name>_filtered<ext>' next to the source database and open it
                # so the filtered timelines can be written there.
                filename, file_extension = os.path.splitext(file)
                target_db = filename + '_filtered' + file_extension
                start = 0
                length = getNumberOfFrames(file)
                extractEventTable(connection, target_db, start, length)
                target_connection = sqlite3.connect(target_db)
            logs_dir = sqlite_dir / filtered_events_dir
            logs_dir.mkdir(exist_ok=True)
            # Most recent "Nest4" timeline of this file, forwarded to every filter call.
            nest4_TL = None

            for name, params in actions_dictionary.items():
                # Use .get so that an action listed in the YAML file but absent from
                # the dispatch table is skipped instead of raising KeyError.
                build_timeline = scripts_events_dictionary.get(name)
                if build_timeline is None:
                    print(f"Skipping '{name}': no post-processing function registered")
                    continue

                n_ids = params['n_ids']
                for id1 in range(1, 5):
                    for id2 in range(1, 5):
                        # Two or more animals: IDs must differ. Otherwise id2 is
                        # irrelevant, so run the body only once (id2 == 1).
                        if (n_ids > 1 and id1 == id2) or (n_ids <= 1 and id2 > 1):
                            continue
                        for id3 in range(1, 5):
                            # Same rule for the third animal.
                            if (n_ids > 2 and id3 in (id1, id2)) or (n_ids <= 2 and id3 > 1):
                                continue

                            # Blank out the IDs that the action does not use.
                            idA, idB, idC = id1, id2, id3
                            if n_ids == 0:
                                idA = idB = idC = None
                            elif n_ids == 1:
                                idB = idC = None
                            elif n_ids == 2:
                                idC = None

                            # Required YAML keys; they also reach the filter through **params.
                            dilation_factor = params['merge_events_length']
                            minimum_event_length = params['minimum_event_length']
                            eventTimeLine = build_timeline(connection,
                                                           t_min = minf,
                                                           t_max = maxf,
                                                           idA = idA,
                                                           idB = idB,
                                                           idC = idC,
                                                           idD = None,
                                                           nest4_TL = nest4_TL,
                                                           action=name,
                                                           **params)
                            if name == "Nest4":
                                nest4_TL = eventTimeLine

                            label = _ids_label(n_ids, id1, id2, id3)

                            if save_graphics:
                                x_limit = (maxf + timestep)/oneMinute

                                # Generate plot of time spent in action (in minutes)
                                density_events = computeTimeSpent(eventTimeLine,
                                                                  t_min = minf,
                                                                  t_max = maxf,
                                                                  timestep = timestep,
                                                                  events_scale = oneMinute)
                                bins = timestep / oneMinute * np.arange(1, len(density_events) + 1)
                                plotGraphics(bins, density_events,
                                             title = f"Time spend on {name} by mice {label}",
                                             filename = logs_dir / f"time_{name}_{label}.png",
                                             csv_filename = logs_dir / f"time_{name}_{label}_per_bin.csv",
                                             x_limit = x_limit,
                                             color = id_animal_color[id1])

                                # Generate plot of number of actions
                                count_events = computeNumberEvents(eventTimeLine,
                                                                   t_min = minf,
                                                                   t_max = maxf,
                                                                   timestep = timestep)
                                bins = timestep / oneMinute * np.arange(1, len(count_events) + 1)
                                plotGraphics(bins, count_events,
                                             title = f"Number of {name} by mice {label}",
                                             filename = logs_dir / f"number_{name}_{label}.png",
                                             csv_filename = logs_dir / f"number_{name}_{label}_per_bin.csv",
                                             y_label = "Number of events",
                                             x_limit = x_limit,
                                             color = id_animal_color[id1])

                            if save_csv_file:
                                # Save to csv
                                csv_events_file = logs_dir / f"Events_{name}_ID_{label}.csv"
                                saveTimeLine2CSV(eventTimeLine, csv_events_file)

                            if save_sqlite_file:
                                eventTimeLine.deleteEventTimeLineInBase(target_connection)
                                eventTimeLine.saveTimeLine(target_connection)
        finally:
            # Always release the database handles, even when a step above fails.
            connection.close()
            if target_connection is not None:
                target_connection.close()
    print("*** ALL JOBS DONE ***")
    

    Plotting 2 events at the same time

    This part of the notebook is meant to plot 2 events for the same period of time. You can change between:

    • Different actions of the same individual;
    • Same action for different individuals;
    • Different actions for different individuals
    # Resetting parameters
    # Names of the two events to compare.
    action_1 = "Rear isolated"
    action_2 = "SAP"
    # Animal IDs for the first event (passed as idA..idD); None leaves an ID unset.
    id1_1 = 4
    id2_1 = None
    id3_1 = None
    id4_1 = None
    # Animal IDs for the second event (passed as idA..idD).
    id1_2 = 4
    id2_2 = None
    id3_2 = None
    id4_2 = None
    # Frame window to analyse; this overrides the minf / maxf set earlier in the notebook.
    minf = 0
    maxf = 3*oneHour
    
    # Filtering parameters
    # Passed to filterEventTimeLine as merge_events_length and minimum_event_length.
    dilation_factor = 5
    minimum_event_length = 30
    
    # Plotting parameters
    # Bin size in frames; this overrides the timestep set earlier in the notebook.
    timestep = 5 * oneMinute
    def _plot_two_series(x_1, y_1, x_2, y_2, x_max, y_label, title, legend):
        """Show two series on one figure: first in solid black, second in dashed gray."""
        plt.plot(x_1, y_1, '-o', color = "black")
        plt.plot(x_2, y_2, '--o', color = "gray")
        plt.xlim(0, x_max)
        # With more than 10 bins, keep only every (len // 7)-th bin as a tick.
        if len(x_1) > 10:
            plt.xticks(x_1[::len(x_1)//7])
        else:
            plt.xticks(x_1)
        plt.xlabel("Time [minutes]")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend(legend)
        #plt.savefig(filename)
        plt.show()
        plt.close()


    for file in files:
        # connect to database
        connection = sqlite3.connect(file)
        try:
            sqlite_dir = Path(file).parent

            # Load and filter the two timelines with the same window and filter settings.
            eventTimeLine_1 = filterEventTimeLine(connection,
                                               action=action_1,
                                               t_min = minf,
                                               t_max = maxf,
                                               idA = id1_1, idB = id2_1, idC = id3_1, idD = id4_1,
                                               merge_events_length = dilation_factor,
                                               minimum_event_length = minimum_event_length)

            eventTimeLine_2 = filterEventTimeLine(connection,
                                               action=action_2,
                                               t_min = minf,
                                               t_max = maxf,
                                               idA = id1_2, idB = id2_2, idC = id3_2, idD = id4_2,
                                               merge_events_length = dilation_factor,
                                               minimum_event_length = minimum_event_length)

            density_events_1 = computeTimeSpent(eventTimeLine_1,
                             t_min = minf,
                             t_max = maxf,
                             timestep = timestep,
                             events_scale = oneMinute)
            density_events_2 = computeTimeSpent(eventTimeLine_2,
                             t_min = minf,
                             t_max = maxf,
                             timestep = timestep,
                             events_scale = oneMinute)
            # x positions: the end of each bin, in minutes.
            bins_1 = timestep / oneMinute * np.arange(1, len(density_events_1) + 1)
            bins_2 = timestep / oneMinute * np.arange(1, len(density_events_2) + 1)

            x_max = (maxf + timestep)/oneMinute
            legend = [action_1+f" {id1_1}", action_2+f" {id1_2}"]

            plt.rcParams['figure.figsize'] = [15, 5]
            _plot_two_series(bins_1, density_events_1, bins_2, density_events_2,
                             x_max, "Time [minutes]",
                             f"Time on actions by mice {id1_1}", legend)

            count_events_1 = computeNumberEvents(eventTimeLine_1,
                                 t_min = minf,
                                 t_max = maxf,
                                 timestep = timestep)
            count_events_2 = computeNumberEvents(eventTimeLine_2,
                                 t_min = minf,
                                 t_max = maxf,
                                 timestep = timestep)
            _plot_two_series(bins_1, count_events_1, bins_2, count_events_2,
                             x_max, "Number of events",
                             f"Number of actions by mice {id1_1}", legend)
        finally:
            # The original cell never closed the connection; release it even on failure.
            connection.close()