"""Converts Raven format dataframe annotations to YOLO format"""

import argparse
import random
import os
import sys
import numpy as np
import cv2
import utils
import matplotlib.pyplot as plt
import soundfile as sf
from p_tqdm import p_map
from tqdm import tqdm
import pandas as pd


def main(entry, arguments, species_list):
    """
    Process the annotations of one recording into YOLO spectrograms and label files.

    The recording is cut into windows of ``arguments.duration`` seconds, starting every
    ``arguments.duration - arguments.overlap`` seconds. For every window that contains at
    least one annotation this writes:
      * ``<directory>/labels/<name>.txt``: one ``id x y width height`` line per box,
        with coordinates normalised to [0, 1];
      * the spectrogram to ``<directory>/images/<species>/<name>.jpg`` (species of the
        last annotation iterated in the window) and ``<directory>/images/all/<name>.jpg``;
      * a copy with the boxes drawn on it to ``<directory>/annotated_images/<name>.jpg``.
    ``<name>`` is the recording path with '/' and '.' replaced by '_', followed by
    '_<window offset>'.

    Box colours come from the module-level ``colors`` list built in the ``__main__`` block.

    :param entry (tuple): ``(index, (filename, group))`` as produced by
        ``enumerate(df.groupby('Path'))``; the index is ignored.
    :param arguments (argparse.Namespace): parsed command-line arguments (not modified).
    :param species_list (pandas.DataFrame): species table with a ``species`` column; the
        row index label of a species is used as its YOLO class id.
    """
    _, (filename, grp) = entry

    try:
        info = sf.info(filename)  # Collect recording information
        file_duration, fs = info.duration, info.samplerate
    except Exception as error:
        print(f'`{filename}` cannot be open... : {error}')
        return

    # Target sample rate. Kept local so that a missing --rf does not leak this file's
    # sample rate into the shared ``arguments`` object (and thus into other files).
    rf = arguments.rf or fs

    # Start time of every candidate spectrogram window in the recording
    offset_list = np.arange(
        0, file_duration, arguments.duration - arguments.overlap)
    grp = utils.split_annotations(grp, arguments.duration)
    grp = grp.reset_index(drop=True)

    # The output folders are the same for every window: create them once
    for folder in ['images', 'labels', 'images/all', 'annotated_images']:
        utils.create_directory(os.path.join(arguments.directory, folder))
    for specie_num in species_list[species_list.columns[0]]:
        utils.create_directory(os.path.join(
            arguments.directory, 'images', str(specie_num)))

    while len(grp) != 0:
        if len(offset_list) == 0:
            # No window left for the remaining annotations (e.g. they lie past the
            # end of the recording): stop instead of looping forever.
            break

        offset = offset_list[0]  # Start time of the current window
        # Annotations starting before the end of the current window
        table = grp[grp.start < offset + arguments.duration]

        # If no annotation falls in this window, go to the next one
        if len(table) == 0:
            offset_list = offset_list[1:]
            continue

        name = str(grp.iloc[0].Path.replace(
            '/', '_').replace('.', '_') + '_' + str(offset))

        sig, fs = sf.read(filename, start=int(offset * fs),
                          stop=int((offset + arguments.duration) * fs),
                          always_2d=True)  # Load the signal of this window
        sig = sig[:, 0]  # Only take channel 0
        # Apply resample and high/low pass filter
        sig = utils.signal_processing(
            sig, rf=rf, fs=fs, high=arguments.high, low=arguments.low)
        # Draws the spectrogram on the current matplotlib figure (saved below)
        utils.create_spectrogram(
            sig, arguments.directory, names=None,
            window_size=arguments.window,
            overlap=arguments.hop, cmap=arguments.cmap, minimum=arguments.minimum)

        annotation = pd.DataFrame(columns=['id', 'x', 'y', 'width', 'height'])
        for _, row in table.iterrows():
            specie = row.species
            # Time axis, normalised by the window length
            # (midl / midl_y presumably are the box centre in time / frequency)
            x_pxl = (row.midl - offset) / arguments.duration
            width_pxl = (row.stop - row.start) / arguments.duration
            # Frequency axis, normalised by rf / 2 and flipped so that higher
            # frequencies get smaller y (top of the image)
            y_pxl = 1 - (row.midl_y / (rf / 2))
            height_pxl = (row.max_freq - row.min_freq) / (rf / 2)

            # Correction if the boxes are corrupted (> 1 or < 0)
            x_pxl, width_pxl = utils.correct_box(x_pxl, width_pxl)
            y_pxl, height_pxl = utils.correct_box(y_pxl, height_pxl)
            if x_pxl > 1 and width_pxl < 0.05:
                continue
            # Store the annotation; its class id is the species' index label
            new_table = pd.DataFrame(
                [[str(species_list[species_list.species == specie].index[0]),
                  x_pxl, y_pxl, width_pxl, height_pxl]],
                columns=['id', 'x', 'y', 'width', 'height'])

            # Drop all-NaN columns before concatenating (presumably to avoid pandas'
            # FutureWarning about empty / all-NA entries in concat)
            annotation = annotation.dropna(axis=1, how='all')
            new_table = new_table.dropna(axis=1, how='all')
            annotation = pd.concat([annotation, new_table])

        grp = grp.drop(table.index)

        name_file = os.path.join(arguments.directory, 'labels', f'{name}.txt')
        all_image_path = os.path.join(
            arguments.directory, 'images', 'all', f'{name}.jpg')
        # Save the images and annotation
        plt.savefig(os.path.join(
            arguments.directory, 'images',
            str(species_list[species_list.species == specie].species.iloc[0]),
            f'{name}.jpg'))
        annotation.to_csv(name_file, sep=' ', header=False, index=False)
        plt.savefig(all_image_path)
        plt.close()

        # Draw the boxes on a copy of the spectrogram in another folder
        image = cv2.imread(all_image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        for num, annot in annotation.iterrows():
            shp1, shape1, shp4, shape4 = utils.get_box_shape(
                [annot, num], image)

            text_shape = shp1[0], shp1[1] - 5
            label = annot['id']
            # Colour each box by its own class id, not by the window's last species
            color = colors[int(label)]

            # Box outline, then a filled rectangle as the label background
            cv2.rectangle(image, pt1=shape1, pt2=shape4, color=color, thickness=1)
            cv2.rectangle(image, pt1=shp1, pt2=shp4, color=color, thickness=-1)
            # Add the associated label
            cv2.putText(image, label, text_shape,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        plt.imshow(image)
        plt.subplots_adjust(top=1, bottom=0, left=0,
                            right=1)  # Remove the border
        plt.savefig(
            os.path.join(arguments.directory, 'annotated_images', f'{name}.jpg'))
        plt.close()


if __name__ == '__main__':
    # Local import so the file's top-level import order (utils before pyplot) is untouched
    import subprocess

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     description='Create .txt and .jpg to each annotation '
                                     'from a csv')
    parser.add_argument('filename_path', type=str,
                        help='Path/name of the folder/file containing the annotations. If a file '
                        'use Raven format and add a Path columns with the path to the '
                        '.wav files')
    parser.add_argument('path_to_data', type=utils.arg_directory,
                        help='Path of the folder that contains the recordings')
    parser.add_argument('directory', type=utils.arg_directory,
                        help='Directory to which spectrograms and .txt files will be stored')
    parser.add_argument('--duration', type=int,
                        help='Duration for each spectrogram', default=8)
    parser.add_argument('--cmap', type=str, choices=['jet','viridis','cividis'],
                        help='Colormap of the Spectrograms', default='viridis')
    parser.add_argument('--overlap', type=float,
                        help='Overlap in seconds between 2 spectrograms', default=0)
    parser.add_argument('--rf', type=int, help='Frequency Resampling', default=None)
    parser.add_argument('--window', type=int, help='Window size for the Fourier Transform', 
                        default=1024)
    parser.add_argument('--hop', type=float, help='Ratio of hop in window : 50%% = 0.5', 
                        default=.5)
    parser.add_argument('--cpu', type=int, help='To speed up the process, write 2 or more', 
                        default=1)
    parser.add_argument('--high', type=int,
                        help='High Pass Filter value in Hz', default=10)
    parser.add_argument('--low', type=int,
                        help='Low Pass Filter value in Hz', default=None)
    parser.add_argument('--test', action='store_const', const=1,
                        help='Split into train/test/val. 1 - Ratio / 2 for test and'
                        ' same for validation', default=None)
    parser.add_argument('--minimum', type=str,
                        help='If True, vmin will be stft.mean(), else stft.min()', default=True)
    args = parser.parse_args()

    # Load the data and put it into a DataFrame
    df = utils.open_file(args.filename_path)
    suffix = input('Which suffix for your recording data? [wav, WAV, Wav, flac, mp3, <other>] : ')

    if len(df.columns) == 1:
        # The input lists annotation tables: load each one and point it at its recording
        final = []
        for file, _ in df.groupby('Path'):
            new_df = utils.open_file(file)
            if len(new_df) >= 1:
                new_df['Path'] = os.path.join(args.path_to_data, str(file.split('/')[-1].split('.Table')[0]+f'.{suffix}'))
                final.append(new_df)
        if not final:
            sys.exit('No annotation found in the listed annotation files')
        df = pd.concat(final)
    elif 'Path' not in df.columns:
        df['Path'] = os.path.join(args.path_to_data, args.filename_path.split('/')[-1].split('.Table')[0]+f'.{suffix}')
    df, species = utils.prepare_dataframe(df, args)

    # One random colour per species, indexed by class id (read as a global by main())
    colors = [(random.randint(0, 255), random.randint(0, 255),
               random.randint(0, 255)) for _ in range(len(species))]

    species.to_csv(os.path.join(
        args.directory, 'species_list.csv'), index=False)

    groups = df.groupby('Path')
    n_files = len(groups)
    if args.cpu == 1:
        for i in tqdm(enumerate(groups), total=n_files,
                      desc="Processing", ascii='░▒▓█'):
            main(i, args, species)
    else:
        p_map(main, enumerate(groups), [args] * n_files,
              [species] * n_files, num_cpus=args.cpu, total=n_files)
    print('saved to', args.directory)

    if not args.test:
        # Ask user if the script split the data or not
        split_answer = input(
            'Do you want to split your data into a random train/test/val ? [Y/N] : ')
    else:
        split_answer = 'Y'

    if split_answer.strip().upper() == 'Y':
        print('The train set will be 70%, val set 15% and test set 15%')

        # The split script lives next to this one
        path = os.path.abspath(os.path.dirname( __file__ ))
        script = os.path.join(path, 'get_train_val.py')
        # main() wrote the labels relative to the current working directory, so resolve
        # the output directory from there (not from this script's folder)
        directory_path = os.path.abspath(args.directory)
        data_path = os.path.join(directory_path, 'labels')

        # Create the directory path if not exists
        utils.create_directory(directory_path)
        print(f'Train saved in {directory_path}\n')
        try:
            # Argument list without a shell: paths with spaces are passed verbatim
            completed = subprocess.run(
                [sys.executable, script, data_path, directory_path,
                 '-r', '0.7', '--test'],
                check=False)
        except OSError as error:
            print(error)
        else:
            if completed.returncode != 0:
                print(f'`{script}` exited with status {completed.returncode}')