"""Separates training and validation datasets in a balanced manner"""

import argparse
import os
import pandas as pd
import utils
import glob

from tqdm import tqdm


def export_split(argument, entry, path, directory):
    """
    Export the annotation sets to disk and write the YOLO data config file.
    :param argument (Namespace): Parsed CLI arguments; only the .test flag is read
    :param entry (list): [train_set, val_set] or [train_set, val_set, test_set]
        (DataFrames with a 'file' column of label-file paths)
    :param path (str): Path to take the labels files
    :param directory (str): Directory to save the data
    """
    train_set = entry[0]
    val_set = entry[1]

    for folder in ['images', 'labels', 'images/train', 'images/val', 'images/test',
                    'labels/train', 'labels/val', 'labels/test']:
        utils.create_directory(os.path.join(directory, folder))

    def _strip_paths(files):
        # 'some/dir/name.txt' -> 'name' (drop directory and extension)
        return ['.'.join(x.split('/')[-1].split('.')[:-1]) for x in files]

    if argument.test:
        test_set = entry[2]
        # BUG FIX: previously read the `test` global from __main__ instead of
        # the local `test_set`, raising NameError when called from elsewhere.
        test_set.file = _strip_paths(test_set.file)

        utils.copy_files_to_directory(test_set.file, path, os.path.join(
            directory, 'labels/test'), 'txt')
        utils.copy_files_to_directory(test_set.file, os.path.join(
            path, '../images/all'), os.path.join(directory, 'images/test'), 'jpg')

    # BUG FIX: previously read the `val` global instead of `val_set`.
    val_set.file = _strip_paths(val_set.file)
    train_set.file = _strip_paths(train_set.file)

    # Copy the validation set into the folder
    utils.copy_files_to_directory(val_set.file, path, os.path.join(
        directory, 'labels/val'), 'txt')
    utils.copy_files_to_directory(val_set.file, os.path.join(
        path, '../images/all'), os.path.join(directory, 'images/val'), 'jpg')
    # Copy the training set into the folder
    utils.copy_files_to_directory(train_set.file, path, os.path.join(
        directory, 'labels/train'), 'txt')
    utils.copy_files_to_directory(train_set.file, os.path.join(
        path, '../images/all'), os.path.join(directory, 'images/train'), 'jpg')

    species_list = None
    try:
        species_list = pd.read_csv(os.path.join(path, '../species_list.csv'))
    except FileNotFoundError:
        print('No species list detected, please add it to : ',
              os.path.join(directory, 'custom_data.yaml'))

    with open(os.path.join(directory, 'custom_data.yaml'), 'w', encoding='utf-8') as f:
        if argument.test:
            f.write(f'test: {os.path.join(directory, "images/test")}\n')
        f.write(f'train: {os.path.join(directory, "images/train")}\n')
        f.write(f'val: {os.path.join(directory, "images/val")}\n')
        # BUG FIX: only write class info when the species list was found;
        # previously `species_list` was undefined here and raised NameError.
        if species_list is not None:
            f.write(f'nc: {len(species_list)}\n')
            f.write(f'names: {[str(x) for x in species_list.species.tolist()]}')


def prepare_data(arguments):
    """
    Collect every annotation from the .txt label files before splitting.
    :param arguments (Namespace): CLI arguments; only .path_to_data is read
    :return annotations (DataFrame): one row per annotation with columns
        ['species', 'x', 'y', 'w', 'h', 'file']; empty if no file has annotations
    :return background (list): paths of empty (background-only) label files
    """
    annotations = []
    background = []
    for f in tqdm(glob.glob(os.path.join(arguments.path_to_data, '*.txt')),
                  desc="Processing", ascii='░▒▓█'):
        file_annotation = pd.read_csv(f, sep=' ', names=['species', 'x', 'y', 'w', 'h'])
        if len(file_annotation) == 0:
            # No boxes in this label file: treat the image as background
            background.append(f)
        else:
            file_annotation['file'] = f
            annotations.extend(file_annotation.to_dict(orient='records'))

    annotations = pd.DataFrame(annotations)
    # BUG FIX: guard the cast — an all-background dataset yields an empty
    # DataFrame with no 'species' column, which previously raised AttributeError.
    if not annotations.empty:
        annotations.species = annotations.species.astype(float)
    return annotations, background


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Split the annotation into train, val and test if needed')
    parser.add_argument('path_to_data', type=utils.arg_directory,
                        help='Path of the folder that contains the .txt (ending with labels/)')
    parser.add_argument('directory', type=utils.arg_directory,
                        help='Directory to which spectrogram and .txt files will be'
                             ' stored (different from path_to_data)')
    parser.add_argument('-r', '--ratio', type=float,
                        default=0.7, help='Train Ratio (val = 1 - ratio)')
    parser.add_argument('--test', action='store_const', const=1,
                        help='Split into train/test/val. 1 - Ratio / 2 '
                        ' for test and same for validation', default=None)
    args = parser.parse_args()

    df, background = prepare_data(args)
    train, val = utils.split(df, 'train', args.ratio)

    saved_directory = os.path.join(args.directory, 'set')
    utils.create_directory(saved_directory)

    # Build the list of sets once so export and the balance report
    # always operate on the same data (removes the duplicated branches).
    if args.test:
        val, test = utils.split(val, 'val', 0.5)
        sets = [train, val, test]
    else:
        sets = [train, val]
    export_split(args, sets, args.path_to_data, saved_directory)
    state, proposition = utils.get_set_info(sets)

    print(f'\nYour dataset is {state} {proposition}')

    print(f'Train saved in {saved_directory}\n')
    print('To train your model, use the following command : \n')

    current_path = os.path.abspath(os.path.dirname(__file__))
    directory_path = os.path.join(current_path, saved_directory)

    # Paths needed to launch a yolov5 training run on the exported split
    yolo_path = os.path.join(current_path, 'yolov5/train.py')
    data_path = os.path.join(directory_path, 'custom_data.yaml')
    weights_path = os.path.join(current_path, 'yolov5/weights/yolov5l.pt')
    hyp_path = os.path.join(current_path, 'custom_hyp.yaml')

    command = (f'python {yolo_path} --data {data_path} --imgsz 640 --epochs 100 '
               f'--weights {weights_path} --hyp {hyp_path} --cache')
    print(command, '\n')

    if not background:
        # TYPO FIX: "Comptue" -> "Compute"
        print('\u26A0\uFE0F   Be aware that it is recommended to have background images that',
              'represents 10% of your dataset. If you do not have background, use the script "get_spectrogram.py"',
              'with --background arguments. Compute on recordings that contains multiple type of noise...')
    else:
        utils.split_background(background, args)
        background_ratio = (len(background) / len(df)) * 100
        print(f'Your dataset contains {len(background)} images in background. It represents ',
              f'{background_ratio:.1f} % of your dataset set. It is recommended to reach around',
              ' 10% for a good model.')