"""Define all the function that are used in the repository"""

import argparse
import glob
import shutil
import os
import base64
import json
from datetime import date
from pathlib import Path
import librosa
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
from scipy import signal
from tqdm import tqdm


def arg_directory(path):
    """
    Validate a command-line argument that must name an existing directory.
    :param path (str): Candidate folder path.
    :return str: The same path when it is an existing directory.
    :raise argparse.ArgumentTypeError: When `path` is not an existing directory.
    """
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError(f'{path} is not a valid path')
    return path


def copy_files_to_directory(file_list, source, directory, suffix):
    """
    Copy files, identified by their stem, from one folder to another
    (metadata preserved via shutil.copy2).
    :param file_list (list of str): File names without extension.
    :param source (str): Folder holding the original files.
    :param directory (str): Destination folder; it is expected to exist already.
    :param suffix (str): Extension shared by the files, without the leading dot.
    """
    full_names = (f'{stem}.{suffix}' for stem in file_list)
    for name in full_names:
        shutil.copy2(os.path.join(source, name), os.path.join(directory, name))


def create_directory(directory):
    """
    Create a single directory level unless something already exists at that path.
    A message is printed only when the directory is actually created.
    :param directory (str): Directory to create (its parent must exist).
    """
    if os.path.exists(directory):
        return
    os.mkdir(directory)
    print(f'`{directory}` has been created')


def signal_processing(sig, rf, fs, high=None, low=None):
    """
    Optionally resample a signal, then optionally high-pass and low-pass filter it.

    Resampling uses scipy.signal.resample (FFT based). The filters are Butterworth
    filters in second-order sections: order 2 for the high pass, order 1 for the
    low pass. Cut-offs are normalised by the Nyquist frequency of the (possibly
    resampled) signal.

    :param sig (array): Signal.
    :param rf (int): Resampling frequency; a falsy value keeps the original rate.
    :param fs (int): Original sampling frequency.
    :param high (int): High pass cut-off in Hz; falsy disables the filter (default None).
    :param low (int): Low pass cut-off in Hz; falsy disables the filter (default None).
    :return array: Processed signal (the input object itself if nothing was applied).
    """
    target_rate = rf if rf else fs
    if target_rate != fs:
        sig = signal.resample(sig, int(len(sig) * target_rate / fs))

    nyquist = target_rate / 2
    if high:
        high_sos = signal.butter(2, high / nyquist, 'hp', output='sos')
        sig = signal.sosfilt(high_sos, sig)
    if low:
        low_sos = signal.butter(1, low / nyquist, 'lp', output='sos')
        sig = signal.sosfilt(low_sos, sig)
    return sig


def create_spectrogram(sig, directory, names, cmap, minimum, window_size=1024, overlap=.5):
    """
    Compute a log10-magnitude STFT (Hann window) of `sig` and render it with matplotlib.

    When `names` is given, the image is saved to `{names}.jpg` and the figure is closed.
    Otherwise the open figure is returned.

    :param sig (array): Signal to process.
    :param directory (str): Base folder. The folder named by the second-to-last
        '/'-separated component of `names` is created inside it (presumably the
        folder `names` points into — verify against callers).
    :param names (str): Output path without extension; falsy to return the figure instead.
    :param cmap (str): Name of the colormap for matplotlib.
    :param minimum (str | bool): If 'True' (or True), the colour scale starts at the mean
        of the finite log-magnitudes, else at their minimum.
    :param window_size (int): Number of samples per STFT window.
    :param overlap (float): If < 1, fraction of overlapping samples between windows
        (default 50%); if >= 1, number of overlapping samples.
    :return: The matplotlib Figure when `names` is falsy, otherwise None.
    """
    if overlap >= 1:
        hop = window_size - overlap
        print(f'You put a hop value over 1. This has been corrected to have {overlap} as overlap size between window')
    else:
        # Hop length is the number of audio samples between adjacent STFT columns
        hop = window_size * (1 - overlap)

    stft = librosa.stft(sig, n_fft=int(window_size),
                        hop_length=int(hop), window='hann')  # Compute the STFT
    # Zero-magnitude bins (e.g. digital silence) become -inf; that is expected here.
    with np.errstate(divide='ignore'):
        stft = np.log10(np.abs(stft))

    # Base the colour floor on finite bins only: a single -inf would otherwise
    # make vmin -inf and the colour normalisation degenerate.
    finite = stft[np.isfinite(stft)]
    if finite.size == 0:
        vmin = None
    elif minimum in ('True', True):
        vmin = finite.mean()
    else:
        vmin = finite.min()

    fig = plt.figure()
    # Flip vertically so low frequencies are at the bottom of the image
    plt.imshow(stft[::-1], aspect='auto',
               interpolation=None, cmap=cmap, vmin=vmin)
    # Remove all the borders around the plot
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
    if not names:
        return fig

    folder = names.split('/')[-2]
    create_directory(os.path.join(directory, folder))
    plt.savefig(f'{names}.jpg')
    plt.close(fig)
    return None


def _move_file(df, annotation, counts, parts):
    """Move every row of file `annotation` from `df` into `parts`, add its per-species
    row counts to the `counts` dict, and return the remaining rows of `df`."""
    mask = df.file == annotation
    picked = df[mask]
    parts.append(picked)
    # Count every species contained in the file, not only the one that was sampled.
    for specie, number in picked.species.value_counts().items():
        counts[specie] = counts.get(specie, 0) + number
    return df[~mask]


def split(df, method, ratio=0.7):
    """
    Split an annotation DataFrame into 2 groups, file by file, aiming at `ratio`
    of each species' annotations in the major group.

    All annotations of one file always end up in the same group. Each species is
    first seeded with one file in each group (when enough files remain). The
    remaining files are then assigned greedily, least-represented species first.
    The per-species ratios obtained are printed.

    :param df (DataFrame): DataFrame containing the annotation with 2 columns :
        'species' : class identifier; and 'file' : Path of the file that contains the annotation.
    :param method (str): Name of the major split, used in the printed summary
        ('train' is paired with 'val', anything else with 'test').
    :param ratio (float): Target fraction of each species' annotations in major instead of minor.
    :return major_df: DataFrame containing the major part of the annotations.
    :return minor_df: DataFrame containing the minor part of the annotations.
    """
    classes = df.species.unique()
    # Number of annotations already assigned to each group, per species
    major_count = {specie: 0 for specie in classes}
    minor_count = {specie: 0 for specie in classes}
    major_parts, minor_parts = [], []

    # Seed each species with one file in each group
    for specie in classes:
        if not (df.species == specie).any():
            print(
                f"Warning: The species '{specie}' was not found in the DataFrame.")
            continue
        for counts, parts in ((major_count, major_parts), (minor_count, minor_parts)):
            if counts.get(specie, 0):
                continue
            # Re-select from the remaining rows so a file already moved to the
            # other group cannot be sampled again.
            data = df[df.species == specie]
            if data.empty:
                break
            annotation = data.sample(1).file.iloc[0]
            df = _move_file(df, annotation, counts, parts)

    # Assign the remaining files until no data is left
    while len(df):
        # Least common species among the remaining annotations
        min_esp = df.groupby('species').count().file.idxmin()
        data = df[df.species == min_esp]
        annotation = data.sample(1).file.iloc[0]
        in_major = major_count.get(min_esp, 0)
        assigned = in_major + minor_count.get(min_esp, 0)
        # Nothing assigned yet counts as "below ratio" and goes to major
        if assigned and in_major / assigned > ratio:
            df = _move_file(df, annotation, minor_count, minor_parts)
        else:
            df = _move_file(df, annotation, major_count, major_parts)

    major = pd.Series(major_count, dtype=float)
    minor = pd.Series(minor_count, dtype=float)
    res = (major / (major + minor)).to_frame(f'{method} ratio')
    other = 'val' if method == 'train' else 'test'
    res[f'{other} ratio'] = 1 - res[f'{method} ratio']
    res = res.reset_index().rename(columns={'index': 'class'})
    print('\n', res)

    major_df = pd.concat(major_parts) if major_parts else pd.DataFrame()
    minor_df = pd.concat(minor_parts) if minor_parts else pd.DataFrame()
    return major_df, minor_df


def split_background(file_list, arguments):
    """
    Randomly split the background files into the sets and copy them there.

    Each background is a pair `<stem>.txt` (in `arguments.path_to_data`) and
    `<stem>.jpg` (in `<path_to_data>/../images/`). Both are copied into
    `<arguments.directory>/labels/<set>` and `<arguments.directory>/images/<set>`.
    The stems are shuffled and shared out in (near) equal parts between
    'train', 'test', 'val' (when `arguments.test` is truthy) or 'train', 'val'.
    Every file lands in exactly one set.

    :param file_list (list): List with all the filenames (with extension) of the background.
    :param arguments (args): Arguments providing `test`, `path_to_data` and `directory`.
    """
    # Drop the extension to keep the stem only
    stems = ['.'.join(name.split('.')[:-1]) for name in file_list]
    random.shuffle(stems)

    sets = ['train', 'test', 'val'] if arguments.test else ['train', 'val']
    total = len(stems)
    # Integer slice bounds (slicing with floats raises TypeError); the last set
    # always runs to `total` so no file is dropped by rounding.
    bounds = [i * total // len(sets) for i in range(len(sets))] + [total]

    source_txt = arguments.path_to_data
    source_img = os.path.join(arguments.path_to_data, '../images/')
    for set_name, start, stop in zip(sets, bounds, bounds[1:]):
        chunk = stems[start:stop]
        directory_txt = os.path.join(arguments.directory, f'labels/{set_name}')
        directory_img = os.path.join(arguments.directory, f'images/{set_name}')
        copy_files_to_directory(chunk, source_txt, directory_txt, 'txt')
        copy_files_to_directory(chunk, source_img, directory_img, 'jpg')

def open_file(path):
    """
    Load a table from `path`, choosing the reader from the text after the last dot.

    - 'pkl': pandas pickle (only open trusted files: unpickling can execute code).
    - 'csv': comma separated, re-read with ';' if the first read raises; a file
      named `species_list.csv` is skipped and gives an empty DataFrame.
    - 'nc': NetCDF loaded with xarray and converted to a DataFrame.
    - 'txt': tab separated.
    - 'wav' / 'WAV' / 'Wav': not loadable, gives an empty DataFrame.
    - anything else: `path` is treated as a folder, and the result has a single
      'Path' column listing its entries.

    :param path (str): Path to the file to open or the folder whose files are listed.
    :return df: DataFrame.
    """
    suffix = path.split('.')[-1]

    if suffix == 'pkl':
        print('Try to load as pickle...')
        return pd.read_pickle(path)

    if suffix == 'csv':
        if path.split('/')[-1] == 'species_list.csv':
            return pd.DataFrame()
        print('Try to load as csv...')
        try:
            return pd.read_csv(path)
        except Exception:
            return pd.read_csv(path, sep=';')

    if suffix == 'nc':
        print('Try to load as netcdf...')
        return xr.load_dataset(path).to_dataframe()

    if suffix == 'txt':
        print('Try to load as txt...')
        return pd.read_csv(path, sep='\t')

    if suffix in ('wav', 'WAV', 'Wav'):
        print("Wav files can't be load...")
        return pd.DataFrame()

    print('Collect all files in the folder...')
    entries = glob.glob(os.path.join(path, '*'), recursive=True)
    return pd.DataFrame(entries, columns=['Path'])


def process_json_files(json_dir, img_dir, directory):
    """
    Embed each matching image into its .json annotation.

    For every `<name>.json` in `json_dir` whose `<name>.jpg` exists in `img_dir`,
    the image bytes are base64 encoded into 'imageData', 'imagePath' is set to the
    image path, and the updated JSON is written to `directory/<name>.json`
    (indent 4). JSON files without a matching image are skipped.

    :param json_dir (str): Path to the .json files.
    :param img_dir (str): Path to the .jpg files.
    :param directory (str): Directory to save the results.
    """
    for json_name in (name for name in os.listdir(json_dir) if name.endswith('.json')):
        image_path = os.path.join(img_dir, json_name.replace('.json', '.jpg'))
        if not os.path.exists(image_path):
            continue
        try:
            with open(image_path, 'rb') as handle:
                encoded_image = base64.b64encode(handle.read()).decode('utf-8')
        except FileNotFoundError:
            # The image disappeared between the existence check and the read
            continue

        with open(os.path.join(json_dir, json_name), 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        payload['imageData'] = encoded_image
        payload['imagePath'] = image_path

        with open(os.path.join(directory, json_name), 'w', encoding='utf-8') as handle:
            json.dump(payload, handle, indent=4)


def labelme2yolo(labelme_annotation_path, yolo_directory):
    """
    Convert one LabelMe .json annotation into a YOLO image + label pair.

    The base64 'imageData' is decoded to `<yolo_directory>/images/all/<stem>.jpg`.
    Every 'rectangle' shape is written to `<yolo_directory>/labels/<stem>.txt` as one
    line `<label> <x_center> <y_center> <width> <height>`, with coordinates normalised
    by 'imageWidth' / 'imageHeight'. Other shape types are reported and skipped.
    Boxes extending past the right or bottom edge are reported but still written.
    Both output folders must already exist.

    :param labelme_annotation_path (str): Path to the LabelMe .json file.
    :param yolo_directory (str): Root directory of the YOLO dataset.
    """
    # Load LabelMe annotation
    image_id = Path(labelme_annotation_path).stem
    with open(labelme_annotation_path, 'r', encoding='utf-8') as labelme_annotation_file:
        labelme_annotation = json.load(labelme_annotation_file)

    # YOLO annotation and image paths
    yolo_annotation_path = os.path.join(
        yolo_directory, 'labels', f'{image_id}.txt')
    yolo_image_path = os.path.join(
        yolo_directory, 'images/all', f'{image_id}.jpg')

    with open(yolo_annotation_path, 'w', encoding='utf-8') as yolo_annotation_file:
        yolo_image_data = base64.b64decode(labelme_annotation['imageData'])

        # Write YOLO image
        with open(yolo_image_path, 'wb') as yolo_image_file:
            yolo_image_file.write(yolo_image_data)

        # Write YOLO image annotation
        for shape in labelme_annotation['shapes']:
            if shape['shape_type'] != 'rectangle':
                print(
                    f'Invalid type `{shape["shape_type"]}` in annotation `{labelme_annotation_path}`'
                )
                continue

            label = shape['label']

            # shape['points'] format : [[x1,y1],[x2,y2]] -> two opposite corners in pixels
            (x_a, y_a), (x_b, y_b) = shape['points'][0], shape['points'][1]
            scale_width = 1.0 / labelme_annotation['imageWidth']
            scale_height = 1.0 / labelme_annotation['imageHeight']
            width = abs(x_b - x_a) * scale_width
            height = abs(y_b - y_a) * scale_height

            # YOLO expects the box centre
            x = min(x_a, x_b) * scale_width + width / 2
            y = min(y_a, y_b) * scale_height + height / 2
            if x + width / 2 > 1 or y + height / 2 > 1:
                # Report the annotation the bad values come from (not the image file object)
                print(
                    f'Error with bounding box values over 1 in file {labelme_annotation_path}')
            annotation_line = f'{label} {x} {y} {width} {height}\n'
            yolo_annotation_file.write(annotation_line)


def prepare_dataframe(df, args):
    """
    Rename Raven-style columns and derive time/frequency helper columns.

    The input `df` is modified in place: columns are renamed and 'd_annot',
    'midl' and 'midl_y' are added before the filtered copy is built.

    :param df (DataFrame): Annotations with the Raven columns 'Begin Time (s)',
        'End Time (s)', 'Low Freq (Hz)', 'High Freq (Hz)' and 'Annotation'
        (or their already-renamed equivalents).
    :param args (args): Arguments; `args.duration` is the strict upper bound kept
        for an annotation's length.
    :return df (DataFrame): Annotations shorter than `args.duration`, re-indexed
        from 0 (the previous index is kept in an 'index' column).
    :return species_list (DataFrame): Columns 'species' and 'number', sorted by
        decreasing count (computed before the duration filter).
    """
    raven_to_short = {
        'Begin Time (s)': 'start',
        'End Time (s)': 'stop',
        'Low Freq (Hz)': 'min_freq',
        'High Freq (Hz)': 'max_freq',
        'Annotation': 'species',
    }
    df.rename(columns=raven_to_short, inplace=True)

    per_species = df.groupby('species').size().sort_values(ascending=False)
    species_list = per_species.reset_index()
    species_list.columns = ['species', 'number']

    df['d_annot'] = df.stop - df.start
    df['midl'] = (df.stop + df.start) / 2
    df['midl_y'] = (df.min_freq + df.max_freq) / 2

    kept = df[df.d_annot < args.duration]
    return kept.reset_index(), species_list


def detection2time_freq(annotations_folder, duration, outdir, rf, names, wav, raven):
    """
    Collect all YOLO .txt detections and convert them to time and frequency annotations.

    Each file in `annotations_folder` is read as space-separated rows
    'class x y w h conf' (normalised box centre and size). The text after the last
    '_' and before the next '.' of each file name is taken as the spectrogram start
    time (cast with `astype(int)`). The image width spans `duration` seconds, and the
    image height spans 0 to rf/2 Hz with y = 0 at the top.

    :param annotations_folder (str): Path to the detection .txt files.
    :param duration (int): Duration (s) of one spectrogram.
    :param outdir (str): Output directory.
    :param rf (int): Resampling frequency (Hz).
    :param names (list): Names of the classes, indexed by class id. When empty
        (or None), the class ids themselves are used.
    :param wav (str): Folder containing the original audio files (read only when `raven` is truthy).
    :param raven (int): When truthy, also save one Raven table per matched audio file
        into `<outdir>/Raven_annotation`.
    :return df (DataFrame): Detections with the time/frequency columns.
    :return dir_path (str): Path `<outdir>/YOLO_detection_<day>_<month>_freq_<rf>_duration_<duration>.nc`
        (the file itself is not written here).
    """
    today = date.today()
    out_file = f'YOLO_detection_{today.day}_{today.month}_freq_{rf}_duration_{duration}.nc'

    # Load and process data
    df = pd.concat({f: pd.read_csv(os.path.join(annotations_folder, f),
                                   sep=' ', names=['class', 'x', 'y', 'w', 'h', 'conf'])
                    for f in tqdm(os.listdir(annotations_folder),
                                  desc="Processing", ascii='░▒▓█')},
                   names=['file'])

    df = df.reset_index(level=[0])
    df = df.reset_index(drop=True)
    # Collect start time of the spectrogram
    df['offset'] = df.file.str.split('_').str[-1].str.split('.').str[0]
    # Replace the detection file extension by .WAV
    df['file'] = ['.'.join(name.split('.')[:-1]) + '.WAV' for name in df.file]

    if names is None or len(names) == 0:
        # No class names given: use the class ids, up to the largest id detected.
        # (There is no 'species' column at this point, so ids come from 'class'.)
        total = int(df['class'].max()) if len(df) else -1
        print(
            f'Consider that no names has been put into : {names} list, so it will automatically be from 0 to {total}')
        names = list(range(total + 1))
    df['Annotation'] = df['class'].apply(lambda x: names[int(x)])

    df['pos'] = (df['x'] * duration) + df['offset'].astype(int)
    df['Low Freq (Hz)'] = (1 - df['y']) * (rf / 2) - (df['h'] * (rf / 2)) / 2
    df.loc[df['Low Freq (Hz)'] < 0, 'Low Freq (Hz)'] = 0
    df['High Freq (Hz)'] = (1 - df['y']) * (rf / 2) + (df['h'] * (rf / 2)) / 2
    df['Begin Time (s)'] = df['pos'] - (df['w'] * duration) / 2
    df['End Time (s)'] = df['pos'] + (df['w'] * duration) / 2
    df['duration'] = df['End Time (s)'] - df['Begin Time (s)']

    # Extract the annotation of each file and save them into a .txt file
    if raven:
        folder = 'Raven_annotation'
        create_directory(os.path.join(outdir, folder))
        # Collect all the original filenames and their stem (text before the first dot)
        files = pd.DataFrame(os.listdir(wav), columns=['file_origin'])
        files['file'] = files.file_origin.str.split('.').str[0]
        # Remove the time information in the detection filename
        df['filename'] = ['_'.join(file.split('_')[:-1]) for file in df.file]
        print('\nSaving the Raven Annotations files...\n')
        groups = df.groupby('filename')
        for filename, grp in tqdm(groups, total=groups.ngroups,
                                  desc="Processing", ascii='░▒▓█'):
            # Original files whose (non-empty) stem occurs in the detection filename
            mask = np.array([bool(stem) and stem in filename for stem in files.file],
                            dtype=bool)
            matches = files[mask]
            if matches.empty:
                continue
            # The longest stem is the most specific match (e.g. 'rec10' over 'rec1')
            origin = matches.loc[matches.file.str.len().idxmax(), 'file_origin']
            stem = '.'.join(origin.split('.')[:-1])
            dir_raven = os.path.join(outdir, folder, f'{stem}.Table1.txt')
            grp[
                ['Annotation', 'pos', 'Low Freq (Hz)', 'High Freq (Hz)',
                 'Begin Time (s)', 'End Time (s)', 'duration', 'filename']
            ].to_csv(dir_raven, sep='\t', index=False)
        print(f'Annotation saved in <{outdir}> as {folder}')

    dir_path = os.path.join(outdir, out_file)
    return df, dir_path


def split_annotations(df, duration=8):
    """
    Make annotations fit inside fixed-length spectrogram windows.

    Windows are [k * duration, (k + 1) * duration). When an annotation crosses the
    first window boundary after its start, the outcome depends on the share of
    the annotation lying before that boundary:
      - 20% or less: only the part after the boundary is kept;
      - 80% or more: only the part before the boundary is kept;
      - otherwise: the annotation is split into two at the boundary.
    Only that first boundary is considered. Rows whose 'stop' is not greater than
    'start' are dropped. The input DataFrame is not modified.

    :param df (DataFrame): DataFrame containing the annotations with 'start' and 'stop' columns.
    :param duration (int): Duration of a single spectrogram.
    :return (DataFrame): The resulting annotations. A split row appears twice with
        the same index label. The frame is empty (same columns) when nothing is kept.
    """
    pieces = []
    for _, row in df.iterrows():
        start, end = row['start'], row['stop']
        if not start < end:
            continue
        # End of the spectrogram chunk containing the start of the annotation
        chunk_end = (start // duration + 1) * duration
        if end <= chunk_end:
            # This annotation fits within the current chunk
            pieces.append(row)
            continue

        before = chunk_end - start
        length = end - start
        # Copy before editing: an iterrows row may share memory with `df`.
        if length * 0.2 < before < length * 0.80:
            # Split the annotation into 2 new annotations
            first, second = row.copy(), row.copy()
            first['stop'] = chunk_end
            second['start'] = chunk_end
            pieces.extend([first, second])
        elif before <= length * 0.2:
            # The first segment is at most 20% of the annotation: keep the second part
            kept = row.copy()
            kept['start'] = chunk_end
            pieces.append(kept)
        else:
            # The first segment is at least 80% of the annotation: keep the first part
            kept = row.copy()
            kept['stop'] = chunk_end
            pieces.append(kept)

    # Building from the row Series keeps each row's index label and lets pandas
    # infer column dtypes; it also works when `pieces` is empty.
    return pd.DataFrame(pieces, columns=df.columns)


def get_box_shape(info, im):
    """
    Get the pixel coordinates of two opposite corners of a normalised bounding box.

    :param info (tuple): Pair whose first item is the annotation, with attributes
        `x`, `y` (box centre) and `width`, `height`, all as ratios of the image
        size. The second item is ignored.
    :param im (array): Image array; only its first two dimensions (height, width) are used.
    :return shp1 (tuple): (x, y) in pixels of the corner (left, centre_y + height/2).
    :return shape1 (tuple): Same corner as shp1.
    :return shp4 (tuple): (x, y) in pixels of the corner (right, centre_y - height/2).
    :return shape4 (tuple): Same corner as shp4.
    """
    annotation, _ = info
    H, W = im.shape[0], im.shape[1]
    x, y = annotation.x * W, annotation.y * H
    w, h = annotation.width * W, annotation.height * H

    shape1 = (int(x - 0.5 * w), int(y + 0.5 * h))
    shape4 = (int(x + 0.5 * w), int(y - 0.5 * h))

    # shp1 previously copied shape4, which collapsed the box to a single point
    shp1 = (shape1[0], shape1[1])
    shp4 = (shape4[0], shape4[1])

    return shp1, shape1, shp4, shape4


def get_set_info(entry):
    """
    Check whether the union of the datasets is balanced across species.

    The data is called unbalanced when the rarest species has fewer than 30% of
    the annotations of the most common one.

    :param entry (list): Two DataFrames (train, val) or three (train, val, test),
        each with a 'species' column.
    :return state (str): 'balanced' or 'unbalanced'.
    :return proposition (str): Advice message. When unbalanced it lists, per
        species (in groupby order), the factor max_count / species_count.
    """
    if len(entry) == 2:
        frames = [entry[0], entry[1]]
    else:
        frames = [entry[0], entry[1], entry[2]]
    per_species = pd.concat(frames).groupby('species').size()
    largest = per_species.max()

    if per_species.min() < largest * .3:
        weights_list = (largest / per_species).tolist()
        return 'unbalanced', f'\u274C you should use positive class weights in the custom_hyp.yaml cls_pw, add this {weights_list}'
    return 'balanced', '\u2705 this is good'


def correct_box(x, w):
    """
    Clip a 1-D box (centre `x`, size `w`, both ratios of the image) to [0, 1].

    A box that overflows on one side only is cut at that edge. A box that
    overflows on both sides becomes the full range (x = 0.5, w = 1). A box that
    starts exactly at 0 or ends exactly at 1 while overflowing the other side is
    also clipped.

    :param x (float): Ratio of the center of the box.
    :param w (float): Ratio of the size of the box.
    :return x, w (float): Corrected values (absolute values are returned).
    """
    # Get the beginning and the end of the box
    x0, x1 = (x - (w / 2)), (x + (w / 2))

    # Check the overflow (inclusive bounds so boxes touching an edge are handled)
    if x1 > 1 and x0 >= 0:
        # Overflow on the right only: keep the start, cut at 1
        w = 1 - x0
        x = x0 + w / 2
    elif x0 < 0 and x1 <= 1:
        # Overflow on the left only: keep the end, cut at 0
        w = x1
        x = w / 2
    elif x0 < 0 and x1 > 1:
        # Overflow on both sides: the box covers the whole range
        w = 1
        x = 0.5

    return abs(x), abs(w)