import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy.special import expit as sigmoid
import scipy.signal as signal
import soundfile
import parselmouth
from matplotlib.transforms import Bbox
from tqdm import tqdm
import os
import argparse


def main(inputFolder, output, extend, pred_threshold, display, saveimg, specific_time):
    for filename in os.listdir(inputFolder):
        if os.path.isfile(output + filename[:-6] + '.npy'):
            print(filename + ' Hey I already treated this file ! find the output at ' + output + filename[:-6] + '.npy')
        elif os.path.getsize(os.path.join(inputFolder, filename)) < 5:
            print(os.path.join(inputFolder, filename) + ' This file is empty duh... ')
        elif not filename.endswith('.preds'):
            print(filename + ' I\'m so sorry but I only work with .preds files !')
        else:
            preds = np.load(os.path.join(inputFolder, filename), allow_pickle=True)
            # load the preds into a pandas dataframe
            ok = []
            for key, value in preds.items():
                if 'h1' in key:
                    date = key.split('_')[3] + key.split('_')[4]
                    tempdic = {'filename': key[:-8] + '.flac',
                               'date': pd.Timestamp(pd.to_datetime(date)),
                               'prob_h1': sigmoid(value)}
                    for h in ['h2', 'h3', 'h4', 'h5']:
                        probs = sigmoid(preds[key[:-7] + h + '.flac'])
                        tempdic['prob_' + h] = probs
                    ok.append(tempdic)
            df = pd.DataFrame(ok)
            df.set_index('date', inplace=True)
            df.sort_values('date', inplace=True)
            hydrotochan = {'h1': 3, 'h2': 4, 'h3': 5, 'h4': 8, 'h5': 7}  # channel numbering starts at 1, not 0!
            # fft window length: 1024, fft window hop size: 315
            # pred window length: 103, pred window hop size: 9
            pred_offset = ((103 - 1) * 315 + 1024) / 2  # pred offset is a pred window divided by 2
            pred_hop_size = 9 * 315
            orcatime = np.arange(922) * pred_hop_size / 22050 + pred_offset / 22050  # for OrcaLab, we have 922 predictions per file
            call_count = 0
            call_missed = 0
            out_dic = {}
            if specific_time:
                df = df[specific_time]
            # iterate through the dates
            for index, row in tqdm(df.iterrows(), total=len(df), desc=filename):
                # check throughout the hydros if there is any high prediction before loading the soundfile!
                probs_sum = sum([np.asarray(row['prob_' + hydro] > pred_threshold).sum() for hydro in hydrotochan])
                if probs_sum > 1:
                    soundpath = "/nfs/NAS3/SABIOD/SITE/OrcaLab/YEAR_MONTH_DAY/{}/{:0>2}/{:0>2}".format(
                        row.name.year, row.name.month, row.name.day) + "/" + row['filename']
                    try:
                        sound = parselmouth.Sound(soundpath)
                    except Exception:
                        print('error opening ' + soundpath)
                        continue
                    for hydro in hydrotochan:
                        # print(hydro, index)
                        binary_probs = np.atleast_1d(row['prob_' + hydro] > pred_threshold)
                        if binary_probs.sum() > 1:
                            monosound = sound.extract_channel(hydrotochan[hydro])
                            sndpitches0 = monosound.to_pitch_ac(pitch_floor=1000, pitch_ceiling=2500, voicing_threshold=0.20)
                            pitchtime = sndpitches0.xs()
                            pitch = sndpitches0.selected_array['frequency']
                            if extend > 0:
                                # extend each call by `extend` preds on the right side:
                                # a frame stays positive if any of the `extend` previous frames was positive
                                orcas = np.append(binary_probs[:extend],
                                                  [np.clip(binary_probs[i - extend:i + 1].sum(), a_min=0, a_max=1)
                                                   for i in range(extend, len(binary_probs))])
                            else:
                                orcas = binary_probs
                            # go through all predictions
                            predid = 0
                            while predid < len(orcas):
                                isOrca = orcas[predid]
                                if isOrca:
                                    lencall = 1
                                    while (predid + lencall) < len(orcas) and orcas[predid + lencall]:
                                        lencall += 1
                                    t1 = orcatime[predid]
                                    t2 = orcatime[predid + lencall - 1]
                                    callmask = (pitchtime > t1) & (pitchtime < t2) & (pitch > 0)
                                    call = pitch[callmask]
                                    calltime = pitchtime[callmask]
                                    # if Parselmouth didn't find any pitch within the call, or the call is less than 3 predictions long
                                    if len(call) == 0 or lencall < 3:
                                        # print('Missed call n°', call_missed, ' at :', index, hydro, t1, t2)
                                        call_missed += 1
                                    else:
                                        # let's clean the call of outliers
                                        cleancall = [call[0]]
                                        cleancalltime = [calltime[0]]
                                        for k, v in enumerate(call[1:]):
                                            # if the next pitch is not further than 200 Hz, save it
                                            if abs(v - cleancall[-1]) < 200:
                                                cleancall.append(v)
                                                cleancalltime.append(calltime[k + 1])
                                            # else, if we did not save any pitch for more than 40% of the call:
                                            elif (cleancalltime[-1] == calltime).argmax() < k - len(call) * .40:
                                                # start from scratch right after the big gap
                                                start = (calltime == cleancalltime[-1]).argmax() + 1
                                                cleancall = [call[start]]
                                                cleancalltime = [calltime[start]]
                                                for k, v in enumerate(call[start + 1:]):
                                                    if abs(v - cleancall[-1]) < 200:
                                                        cleancall.append(v)
                                                        cleancalltime.append(calltime[k + start + 1])
                                                break
                                        if display or saveimg:
                                            sig, sr = soundfile.read(soundpath)
                                            plt.close()
                                            fig, ax = plt.subplots(1)
                                            ax.set_title(row['filename'] + ' ' + hydro)
                                            ax.specgram(sig.T[hydrotochan[hydro] - 1], window=signal.windows.hamming(1024),
                                                        Fs=sr, NFFT=1024, noverlap=890, cmap='viridis')
                                            ax.scatter(calltime, call, c='r', s=1, label='original pitch')
                                            ax.scatter(cleancalltime, cleancall, c='b', s=1, label='filtered pitch')
                                            ax.set_ylim((np.clip(min(call) - 2000, a_min=0, a_max=20000),
                                                         np.clip(max(call) + 2000, a_min=0, a_max=20000)))
                                            ax.set_xlim((t1 - 0.5, t2 + 0.5))
                                            ax.plot([t1, t2], [min(call) - 800, min(call) - 800], c='k', label='model\'s prediction')
                                            ax.legend()
                                            # ax.add_patch(Rectangle((t1, call.min()), t2-t1, call.max()-call.min(),
                                            #                        linewidth=1, edgecolor='k', facecolor='none'))
                                            if display:
                                                plt.show()
                                            if saveimg:
                                                plt.savefig(saveimg + '/' + str(call_count) + '.png')
                                        out_dic[call_count] = {'filename': row['filename'], 'hydro': hydro,
                                                               'call': cleancall, 'calltime': cleancalltime}
                                        # print('Saved call ', call_count)
                                        call_count += 1
                                predid += lencall if isOrca else 1
            np.save(output + str(filename[:-6]), out_dic)
            print('saved ' + str(len(out_dic)) + ' calls into ' + output + filename[:-6])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Once your best detector has outlined your favorite animal's calls, this script allows you to "
                    "extract their pitch (using the Praat technology https://parselmouth.readthedocs.io/en/stable/ , "
                    "based on Paul Boersma (1993): \"Accurate short-term analysis of the fundamental frequency and "
                    "the harmonics-to-noise ratio of a sampled sound.\" Proceedings of the Institute of Phonetic "
                    "Sciences 17: 97–110. University of Amsterdam). The dictionary storing the pitch of the given "
                    "calls will be saved as a .npy in the folder of your choice.")
    parser.add_argument("file", type=str, help="Folder of .preds files (dictionaries 'filename': model's prediction array)")
    parser.add_argument("--output", type=str, default='./', help="Output folder")
    parser.add_argument("--extend", type=int, default=0, help="Number of preds to extend the calls on the right side")
    parser.add_argument("--pred_threshold", type=float, default=.9, help="Threshold for considering a prediction true (between 0 and 1)")
    parser.add_argument("--display", action='store_true', help="Display each call and its pitch")
    parser.add_argument("--saveimg", type=str, default=None, help="Path for saving vocalisation images")
    parser.add_argument("--specific_time", type=str, default=None,
                        help="Date for running at a specific time using pandas date indexing, format: YYYY-MM-DD HH:MM:SS")
    args = parser.parse_args()
    main(args.file, args.output, args.extend, args.pred_threshold, args.display, args.saveimg, args.specific_time)
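# ---------------------------------------------------------------------------
# Minimal usage sketch. The script name, folder layout and date below are
# hypothetical examples, not part of the repo:
#
#   python extract_pitch.py ./preds/ --output ./pitch/ --extend 2 \
#       --pred_threshold 0.9 --saveimg ./img --specific_time "2020-07-15"
#
# Each output .npy holds a dict {call_id: {'filename', 'hydro', 'call',
# 'calltime'}}, so a saved file can be read back with:
#
#   import numpy as np
#   calls = np.load('./pitch/<preds_file_without_extension>.npy', allow_pickle=True).item()
#   for call_id, c in calls.items():
#       print(c['filename'], c['hydro'], len(c['call']))
# ---------------------------------------------------------------------------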