Skip to content
Snippets Groups Projects
Select Git revision
2 results Searching

setup.py

Blame
  • utils.py 20.29 KiB
    """Define all the function that are used in the repository"""
    
    import argparse
    import glob
    import shutil
    import os
    import base64
    import json
    from datetime import date
    from pathlib import Path
    import librosa
    
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    import xarray as xr
    from scipy import signal
    from tqdm import tqdm
    
    
    def arg_directory(path):
        """
        Check if the path given in args is a real directory or not.
        :param path (str): The path to a folder.
        :return str: Return the path if correct, error if doesn't exist.
        """
        if os.path.isdir(path):
            return path
        raise argparse.ArgumentTypeError(f'{path} is not a valid path')
    
    
    def copy_files_to_directory(file_list, source, directory, suffix):
        """
        Copy files from a directory to another one
        :param file_list (str): List of the file to copy.
        :param source_dir (str): Directory of the original files.
        :param directory (str): Directory to copy the new files.
        :param suffix (str): Suffix of the files
        """
        for filename in file_list:
            source_file = os.path.join(source, f'{filename}.{suffix}')
            destination_path = os.path.join(directory, f'{filename}.{suffix}')
            shutil.copy2(source_file, destination_path)
    
    
    def create_directory(directory):
        """
        Create a directory if not exists.
        :param directory (str): Directory to create.
        """
        # Check if directory exists
        if not os.path.exists(directory):
            os.mkdir(directory)
            print(f'`{directory}` has been created')
    
    
    def signal_processing(sig, rf, fs, high=None, low=None):
        """
        Resample the signal and apply high pass and low pass filter.
    
        :param sig (array): Signal.
        :param rf (int): Resampling Frequency.
        :param fs (int): Original Sampling Frequency.
        :param high (int): High pass filter value (default None).
        :param low (int): Low pass filter value (default None).
        :return array: Processed signal.
        """
        # Check if resampling frequency is different than sampling frequency
        if not rf:
            rf = fs
        if rf != fs:
            sig = signal.resample(sig, int(len(sig) * rf / fs)
                                  )  # Resample the signal
    
        # Apply high pass filter if specified
        if high:
            # Create high pass filter
            high_pass = signal.butter(2, high / (rf / 2), 'hp', output='sos')
            sig = signal.sosfilt(high_pass, sig)  # Apply high pass filter
    
        # Apply low pass filter if specified
        if low:
            # Create low pass filter
            low_pass = signal.butter(1, low / (rf / 2), 'lp', output='sos')
            sig = signal.sosfilt(low_pass, sig)  # Apply low pass filter
    
        return sig
    
    
    def create_spectrogram(sig, directory, names, cmap, window_size=1024, overlap=.5,):
        """
        Create a spectrogram STFT with hanning window and save it into a directory
    
        :param sig (array): Signal to process.
        :param window_size (int): Number of sample / STFT window.
        :param overlap (float): Ratio of overlapping samples between each window (default 50%).
        :param directory (str): Path to save the spectrogram.
        :param filename (str): Name of the final spectrogram.
        """
        if overlap >= 1:
            print(f'You put a hop value over 1. This has been corrected to have {overlap} as overlap size between window')
            overlap_size = overlap
        else:
            overlap_size = window_size * overlap
    
        stft = librosa.stft(sig, n_fft=int(window_size),
                            hop_length=int(overlap_size), window='hann')  # Compute the STFT
        stft = np.log10(np.abs(stft))  # Adapt the Complex-valued matrix
        fig = plt.figure()
        # plot the spectrogram
        plt.imshow(stft[::-1], aspect='auto',
                   interpolation=None, cmap=cmap, vmin=stft.mean())
        # Remove all the borders around the plot
        plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
        if names:
            folder = names.split('/')[-2]
            create_directory(os.path.join(directory, folder))
            plt.savefig(f'{names}.jpg')
            plt.close()  # Close the figure
            return
        else : 
            return fig # Return the figure
    
    
    def split(df, method, ratio=0.7):
        """
        Split an annotation dataframe into 2 groups with a ratio
        :param df (DataFrame): DataFrame containing the annotation with 2 columns :
        'species' : number between 0 and n; and 'file' : Path of the file that contain the annotation.
        :param ratio (float): Ratio of the annotation in major instead of minor
        :return major_df: DataFrame containing the major part of the annotations.
        :return minor_df: DataFrame containing the minor part of the annotations.
        """
        classes = df.species.unique()
        n_class = classes.size
        # Initialize 2 counters
        major_count = pd.DataFrame(np.zeros((n_class, 1)), index=classes)
        minor_count = major_count.copy()
        # Initialize 2 DataFrame
        major_df = pd.DataFrame()
        minor_df = pd.DataFrame()
        # Go throught the differents classes
        for _, specie in enumerate(classes):
            try:
                data = df.groupby('species').get_group(specie)
            except KeyError:
                print(
                    f"Warning: The species '{specie}' was not found in the DataFrame.")
                continue
            except Exception as error:
                print(
                    f"An unexpected error occurred while processing the species '{specie}': {error}")
                continue
            # Add a first annotation in both major and minor DataFrame
            if major_count.loc[specie].iloc[0] == 0:
                # Random sampling of 1 annotation
                annotation = data.sample(1).file.iloc[0]
                mask = df.file == annotation
                major_count = major_count.add(
                    df[mask].species.value_counts(), axis=0).fillna(0)
                major_df = pd.concat([major_df, df[mask]])
                # Removing the annotation from the original DataFrame
                df = df[~mask]
            if minor_count.loc[specie].iloc[0] == 0:
                # Random sampling of 1 annotation
                annotation = data.sample(1).file.iloc[0]
                mask = df.file == annotation
                minor_count = minor_count.add(
                    df[mask].species.value_counts(), axis=0).fillna(0)
                minor_df = pd.concat([minor_df, df[mask]])
                # Removing the annotation from the original DataFrame
                df = df[~mask]
        # Go throught df to do the split until no data left in df
        while len(df):
            # find the least common species in the DataFrame
            min_esp = df.groupby('species').count().file.idxmin()
            # find all the data of this species
            data = df.groupby('species').get_group(min_esp)
            # Random sampling of 1 annotation
            annotation = data.sample(1).file.iloc[0]
            # Check the actual ratio
            if (major_count.loc[min_esp]/(minor_count.loc[min_esp] +
                              major_count.loc[min_esp]))[0] > ratio:
                # between major and minor
                minor_count.loc[min_esp] += df[df.file ==
                                               annotation].groupby('species').count().iloc[0].file
                minor_df = pd.concat([minor_df, df[df.file == annotation]])
            else:
                major_count.loc[min_esp] += df[df.file ==
                                               annotation].groupby('species').count().iloc[0].file
                major_df = pd.concat([major_df, df[df.file == annotation]])
            # Removing the annotation from the original DataFrame
            df = df[df.file != annotation]
        res = major_count/(minor_count + major_count)
        res.columns = [f'{method} ratio']
        if method == 'train':
            other = 'val'
        else:
            other = 'test'
        res[f'{other} ratio'] = 1 - res[res.columns[0]]
        res = res.reset_index().rename(columns={'index': 'class'})
        print('\n', res)
        return major_df, minor_df
    
    
    def open_file(path):
        """
        Open a file with a path without knowing if suffix is .pkl or .csv
        :param path (str): Path to the file to open or the folder that
        contains all the files to conactenate
        :return df: DataFrame.
        """
        suffix = path.split('.')[-1]  # Extract the suffix of the file
        if suffix == 'pkl':
            print('Try to load as pickle...')
            df = pd.read_pickle(path)
        elif suffix == 'csv':
            if path.split('/')[-1] == 'species_list.csv':
                return pd.DataFrame()
            else:
                print('Try to load as csv...')
                try:
                    df = pd.read_csv(path)
                except Exception:
                    df = pd.read_csv(path, sep=';')
        elif suffix == 'nc':
            print('Try to load as netcdf...')
            ds = xr.load_dataset(path)
            df = ds.to_dataframe()
        elif suffix == 'txt':
            print('Try to load as txt...')
            df = pd.read_csv(path, sep='\t')
        elif suffix == 'wav' or suffix == 'WAV' or suffix == 'Wav':
            print("Wav files can't be load...")
            return pd.DataFrame()
        else:
            print('Collect all files on a folder...')
            df = pd.DataFrame(glob.glob(os.path.join(path, '*'),
                              recursive=True), columns=['Path'])
        return df
    
    
    def process_json_files(json_dir, img_dir, directory):
        """
        Process json annotation and add data information
        :param json_dir (str): Path to the .json files
        :param img (str): Path to the .jpg files
        :param directory (str): Directory to save the results
        """
        json_files = [f for f in os.listdir(json_dir) if f.endswith(
            '.json')]  # Collect all the .json file int the path
    
        for json_file in json_files:  # Process each file one by one
            json_path = os.path.join(json_dir, json_file)
            img_path = os.path.join(img_dir, json_file.replace('.json', '.jpg'))
    
            if not os.path.exists(img_path):
                continue
            try:
                with open(img_path, 'rb') as img_file:  # Load the images
                    image_data = base64.b64encode(img_file.read()).decode(
                        'utf-8')  # Collect the images pixels information
                    # and encode into the correct format
            except FileNotFoundError:
                continue
    
            with open(json_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)
    
            json_data['imageData'] = image_data
            json_data['imagePath'] = img_path
    
            output_path = os.path.join(directory, json_file)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, indent=4)
    
    
    def labelme2yolo(labelme_annotation_path, yolo_directory):
        """
        Process json annotation and convert to labelme format
        :param labelme_annotation_path (str): Path to the .json files
        :param yolo_directory (str): Directory to save the .txt files
        """
        # Load LabelMe annotation
        image_id = Path(labelme_annotation_path).stem
        with open(labelme_annotation_path, 'r', encoding='utf-8') as labelme_annotation_file:
            labelme_annotation = json.load(labelme_annotation_file)
    
        # YOLO annotation and image paths
        yolo_annotation_path = os.path.join(
            yolo_directory, 'labels', f'{image_id}.txt')
        yolo_image_path = os.path.join(
            yolo_directory, 'images/all', f'{image_id}.jpg')
    
        with open(yolo_annotation_path, 'w', encoding='utf-8') as yolo_annotation_file:
            yolo_image_data = base64.b64decode(labelme_annotation['imageData'])
    
            # Write YOLO image
            with open(yolo_image_path, 'wb') as yolo_image_file:
                yolo_image_file.write(yolo_image_data)
    
            # Write YOLO image annotation
            for shape in labelme_annotation['shapes']:
                if shape['shape_type'] != 'rectangle':
                    print(
                      f'Invalid type `{shape["shape_type"]}` in annotation `{labelme_annotation_path}`'
                      )
                    continue
    
                label = shape['label']
    
                # shape['points'] format : [[x1,y1],[x2,y2]...] #
                scale_width = 1.0 / labelme_annotation['imageWidth']
                scale_height = 1.0 / labelme_annotation['imageHeight']
                width = abs(shape['points'][1][0] -
                            shape['points'][0][0]) * scale_width
                height = abs(shape['points'][1][1] -
                             shape['points'][0][1]) * scale_height
    
                x = min(shape['points'][0][0], shape['points']
                        [1][0]) * scale_width + width / 2
                y = min(shape['points'][0][1], shape['points']
                        [1][1]) * scale_height + height / 2
                if x+width/2 > 1 or y+height/2 > 1:
                    print(
                        f'Error with bounding box values over 1 in file {yolo_image_file}')
                annotation_line = f'{label} {x} {y} {width} {height}\n'
                yolo_annotation_file.write(annotation_line)
    
    
    def prepare_dataframe(df, args):
        """
        Prepare the annotation in time frequency
        :param df (DataFrame): DataFrame that contains the annotations' informations
        :param args (args): Argument
        :return df (DataFrame): Prepared DataFrame
        :return colors (list): Color for each class
        :return species_list (list): List of each class
        """
        df.rename(columns={'Begin Time (s)': 'start', 'End Time (s)': 'stop',
                           'Low Freq (Hz)': 'min_freq', 'High Freq (Hz)': 'max_freq', 'Annotation' : 'species'}, 
                           inplace=True)
    
        species_list = df.groupby('species').size().sort_values(
            ascending=False).reset_index()
    
        df['d_annot'] = df.stop - df.start
        df['midl'] = (df.stop + df.start) / 2
        df['midl_y'] = (df.min_freq+df.max_freq)/2
    
        df = df[df.d_annot < args.duration]
        df = df.reset_index()
    
        return df, species_list
    
    
    def detection2time_freq(annotations_folder, duration, outdir, sr, names, wav, raven):
        """
        Collect all .txt detection and get time and frequency informations
        :param annotations_folder (str): Path to the .json files
        :param duration (int): Directory to save the .txt files
        :param outfir (str): Directory to save the .txt files
        :param sr (int): Directory to save the .txt files
        :param names (str): Directory to save the .txt files
        """
        today = date.today()
        out_file = f'YOLO_detection_{today.day}_{today.month}_freq_{sr}_duration_{duration}.nc'
    
        # Load and process data
        df = pd.concat({f: pd.read_csv(os.path.join(annotations_folder, f),
                         sep=' ', names=['class', 'x', 'y', 'w', 'h', 'conf'])
                         for f in tqdm(os.listdir(annotations_folder),
                                         desc="Processing", ascii='░▒▓█')},
                                         names=['file'])
    
        df = df.reset_index(level=[0])
        df = df.reset_index(drop=True)
        # Collect start time of the spectrogram
        df['offset'] = df.file.str.split('_').str[-1].str.split('.').str[0]
        # Remove all the path to keep the file name
        df.file = ['.'.join(x.file.split('.')[:-1]) +
                   '.WAV' for i, x in df.iterrows()]
    
        if len(names) == 0:
            total = len(df.groupby('species').size()) - 1
            print(
                f'Consider that no names has been put into : {names} list, so it will automatically be from 0 to {total}')
            names = np.arange(0, total+1).tolist()
        df['species'] = df['class'].apply(lambda x: names[int(x)])
    
        df['pos'] = (df['x'] * duration) + df['class'].astype(int)
        df['Low Freq (Hz)'] = (1 - df['y']) * (sr / 2) - (df['h'] * (sr / 2)) / 2
        df['High Freq (Hz)'] = (1 - df['y']) * (sr / 2) + (df['h'] * (sr / 2)) / 2
        df['Begin Time (s)'] = df['pos'] - (df['w'] * duration) / 2
        df['End Time (s)'] = df['pos'] + (df['w'] * duration) / 2
        df['duration'] = df['End Time (s)'] - df['Begin Time (s)']
    
        # Extract the annotation of each file and save them into a .txt file
        if raven:
            folder = 'Raven_annotation'
            create_directory(os.path.join(outdir, folder))
            # Collect all the original filename
            files = pd.DataFrame(os.listdir(wav), columns = ['file_origin'])
            files['file'] = files.file_origin.str.split('.').str[0]
            # Remove the time information in the detection filename
            df['filename'] = ['_'.join(file.split('_')[:-1]) for file in df.file]
            print('\nSaving the Raven Annotations files...\n')
            for file, grp in tqdm(df.groupby('filename'),
                                  total=len(df.groupby('filename').size()),
                                  desc="Processing", ascii='░▒▓█'):
                # Check if the filename match with an original file .wav
                for _, row in files.iterrows():
                    if row.file in file:
                        file = row['file_origin']
    
                        file = '.'.join(file.split('.')[:-1])
                        filename_raven = f'{file}.Table1.txt'
                        dir_raven = os.path.join(outdir, folder, filename_raven)
                        grp.to_csv(dir_raven, sep='\t', index=False)
            print(f'Annotation saved in <{outdir}> as {folder}')
    
        dir_path = os.path.join(outdir, out_file)
        return df, dir_path
    
    
    def split_annotations(df, duration=8):
        """
        Split the annotations into multiple segments if they span across different spectrograms.
        :param df (DataFrame): DataFrame containing the annotations with 'start' and 'stop' columns.
        :param duration (int): Duration of a single spectrogram.
        :return (DataFrame): DataFrame containing the split annotations.
        """
        splited_annotations = []
    
        for _, row in df.iterrows():
            start = row['start']
            end = row['stop']
    
            while start < end:
                # Calculate the end of the current spectrogram chunk
                current_chunk_end = (start // duration + 1) * duration
    
                if end > current_chunk_end:
                    if (current_chunk_end - start) >= (end - start) * 0.5:
                        # Split the annotation
                        splited_annotations.append(
                            {'start': start, 'stop': current_chunk_end})
                        start = current_chunk_end
                    else:
                        # If the remaining segment is less than half of the annotation
                        # only keep the longest part
                        splited_annotations.append(
                            {'start': current_chunk_end, 'stop': end})
                        break
                else:
                    # This annotation fits within the current chunk
                    splited_annotations.append({'start': start, 'stop': end})
                    break
    
        return pd.DataFrame(splited_annotations)
    
    
    def get_box_shape(info, im):
        """
        Get the pixels information to place the bounding box in the image.
        :param im (array): Image array in cv2.cvtColor(im, cv2.COLOR_BGR2RGB) format
        :param annotation (DataFrame): DataFrame with x, y, width, and height information
        :return shp1 (tuple): Positions in x and y of the top left point in pixels
        :return shape1 (tuple): Positions in x and y of the top left point in ratio
        :return shp4 (tuple): Positions in x and y of the bottom right point in pixels
        :return shape4 (tuple): Positions in x and y of the bottom right point in ratio
        """
        annotation, _ = info
        H, W = im.shape[0], im.shape[1]
        x, y, w, h = annotation.x * W, annotation.y * \
            H, annotation.width * W, annotation.height * H
    
        shape1 = (int(x - 0.5 * w), int(y + 0.5 * h))
        shape4 = (int(x + 0.5 * w), int(y - 0.5 * h))
    
        shp1 = (shape4[0], shape4[1])
        shp4 = (shape4[0], shape4[1])
    
        return shp1, shape1, shp4, shape4
    
    
    def get_set_info(entry):
        """
        Check if the dataset is balanced
        :param entry (list): List containing train, val and test dataset
        :return state (str): State of the dataset : balanced or unbalanced
        :return proposition (str): If balance, just str, else a str + list of weights
        """
        # Check entry size
        if len(entry) == 2:
            train, val = entry[0], entry[1]
            dataset = pd.concat([train, val])  # Concat the datasets into one
        else:
            train, val, test = entry[0], entry[1], entry[2]
            dataset = pd.concat([train, val, test])  # Concat the datasets into one
    
        # Check whether the minor class is under-represented
        if dataset.groupby('species').size().min() < dataset.groupby('species').size().max()*.3:
            state = 'unbalanced'
            # Calculate the multiple factor to get a balanced dataset
            weights_list = (dataset.groupby('species').size().max() /
                            dataset.groupby('species').size()).tolist()
            proposition = f'\u274C you should use positive class weights in the custom_hyp.yaml cls_pw, add this {weights_list}'
        else:
            state = 'balanced'
            proposition = '\u2705 this is good'
        return state, proposition