from metadata import species
import pandas as pd, numpy as np, os, argparse, librosa, parselmouth, mir_eval
from glob import glob
from tqdm import tqdm
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import crepe, pesto, torchcrepe, torch, tensorflow as tf, basic_pitch.inference, basic_pitch.constants

# Use the GPU with a batch size of 64 when CUDA is available, otherwise the CPU with a batch size of 1
device, batch_size = ('cuda', 64) if torch.cuda.is_available() else ('cpu', 1)

# LOAD MODELS
basic_pitch_model = tf.saved_model.load(str(basic_pitch.ICASSP_2022_MODEL_PATH))

# Out-of-the-box torchcrepe weights: located relative to the installed torchcrepe package
# (instead of a user-specific absolute path) and mapped onto the selected device,
# so that loading also works on other machines and on CPU-only hosts.
tcrepe_weights_path = os.path.join(os.path.dirname(torchcrepe.__file__), 'assets', 'full.pth')
tcrepe_model = torchcrepe.Crepe('full').eval().to(device)
tcrepe_model.load_state_dict(torch.load(tcrepe_weights_path, map_location=device))

# 360 values evenly spaced by 20, offset by 1997.3794084376191
# NOTE(review): presumably the cents value of each CREPE output bin; not referenced in the visible code
cents_mapping = np.linspace(0, 7180, 360) + 1997.3794084376191

def run_tcrepe(model, sig, fs, dt):
    """Run a torchcrepe network over a waveform and return its frame-wise pitch track.

    model: network applied to each batch of frames yielded by torchcrepe's preprocessing
    sig: waveform samples (converted with torch.tensor)
    fs: sampling rate of sig
    dt: hop between analysed frames, in seconds

    Returns a (time, f0, confidence) tuple: the frame times, the frequency obtained from
    the local average of the activations (in cents, converted with a 10 Hz reference),
    and the maximum activation of each frame.
    Relies on the module-level `batch_size` and `device`.
    """
    # the hop stays a float when fs differs from torchcrepe's own rate, and is cast to int otherwise
    if fs != torchcrepe.SAMPLE_RATE:
        hop = dt*fs
    else:
        hop = int(dt*fs)
    audio = torch.tensor(sig).unsqueeze(0)
    batches = torchcrepe.core.preprocess(audio, fs, hop_length=hop, batch_size=batch_size, device=device, pad=False)

    # collect the activations batch by batch, without tracking gradients
    activations = []
    with torch.no_grad():
        for frames in batches:
            activations.append(model(frames).cpu().numpy())
    preds = np.vstack(activations)

    cents = crepe.core.to_local_average_cents(preds)
    f0 = 10 * 2 ** (cents / 1200)
    confidence = np.max(preds, axis=1)

    # frame centres, computed in samples at torchcrepe's rate (no padding: first centre is half a window in)
    half_window = torchcrepe.WINDOW_SIZE/2
    n_samples = len(sig)/fs*torchcrepe.SAMPLE_RATE
    centres = np.arange(half_window, n_samples - half_window + 1e-9, dt*torchcrepe.SAMPLE_RATE)
    time = centres / torchcrepe.SAMPLE_RATE
    return time, f0, confidence

# PARSE ARGUMENTS
def _str2bool(value):
    """Convert the textual value of a command line option into a bool.

    argparse's `type=bool` calls bool() on the raw string, so any non-empty value
    (including 'False') would come out as True. This converter accepts
    true/t/yes/y/1 and false/f/no/n/0 (case-insensitive, empty string is False)
    and raises argparse.ArgumentTypeError for anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

parser = argparse.ArgumentParser()
# positional: a species name, several names separated by spaces, or 'all'
parser.add_argument('specie', type=str, help="Species to treat specifically", default=None)
parser.add_argument('--overwrite', type=_str2bool, help="Overwrite previous pedictions", default=False)
parser.add_argument('--quick', type=_str2bool, help="Skip pyin and crepe to make things quick", default=False)
parser.add_argument('--split', type=int, help="Section to test on between 0 and 4", default=None)
args = parser.parse_args()

# Prediction columns that must all be present in an existing _preds.csv for a file to be skipped
# NOTE(review): 'tcrepe_ft_f0' and 'pesto_ftoth_f0' do not appear to be produced by the loop below — confirm
algos = [
    'praat_f0', 'pyin_f0', 'crepe_f0',
    'tcrepe_f0', 'tcrepe_ft_f0', 'tcrepe_ftsp_f0', 'tcrepe_ftoth_f0',
    'basic_f0',
    'pesto_f0', 'pesto_ft_f0', 'pesto_ftoth_f0',
]
# Subset of columns that is sufficient to skip a file when --quick is set
quick_algos = [
    'praat_f0', 'pyin_f0',
    'tcrepe_f0', 'tcrepe_ft_f0', 'tcrepe_ftsp_f0', 'tcrepe_ftoth_f0',
    'basic_f0',
    'pesto_f0', 'pesto_ft_f0',
]

# let the user know that existing prediction files will be recomputed
if args.overwrite:
    print('Overwriting previous results')

# Iterate over species, then files, then run each algorithm and save the predictions.
# For each wav file, predictions are stored next to it in '<file>_preds.csv', one row per
# time step and one '<algo>_f0' / '<algo>_conf' column pair per algorithm.
for specie in species if args.specie =='all' else args.specie.split(' '):
    # the metadata values are unpacked positionally, so their order in the dict matters
    wavpath, FS, nfft, downsample, step = species[specie].values()
    dt = round(nfft * step / FS * downsample, 3) # winsize / 8
    # Load species specific pre-trained models (left to None when no checkpoint is found)
    tcrepe_ftoth_model, tcrepe_ftsp_model = None, None
    ftsp_weights = f'crepe_ft/model_only-{args.split}_{specie}.pth'
    ftoth_weights = f'crepe_ft/model_omit_{specie}.pth'
    if os.path.isfile(ftsp_weights):
        tcrepe_ftsp_model = torchcrepe.Crepe('full').eval().to(device)
        tcrepe_ftsp_model.load_state_dict(torch.load(ftsp_weights, map_location=device))
    if os.path.isfile(ftoth_weights):
        tcrepe_ftoth_model = torchcrepe.Crepe('full').eval().to(device)
        tcrepe_ftoth_model.load_state_dict(torch.load(ftoth_weights, map_location=device))
    # initialise the file list to iterate on
    fns = glob(wavpath)
    if isinstance(args.split, int):
        # keep only the requested fifth of the file list
        fns = fns[int(len(fns)/5*args.split) : int(len(fns)/5*(args.split+1))]
    # iterate over files
    for fn in tqdm(fns, desc=specie):
        preds_path = f'{fn[:-4]}_preds.csv'
        # start from scratch when overwriting, or when the prediction file is missing or smaller than 300 bytes
        if args.overwrite or not os.path.isfile(preds_path) or os.path.getsize(preds_path) < 300:
            # load original annotation file
            annot = pd.read_csv(f'{fn[:-4]}.csv').drop_duplicates(subset='Time')
            # add a 0 at starts and ends for large gaps to avoid interpolating between vocalisations
            med_diff = annot.Time.diff().median()
            rgaps, lgaps = annot.Time[annot.Time.diff() > med_diff*4], annot.Time[annot.Time.diff(-1) < - med_diff * 4]
            annot = pd.concat([annot, pd.DataFrame({'Time':np.concatenate([lgaps+med_diff, rgaps-med_diff]), 'Freq':[0]*(len(lgaps)+len(rgaps))})]).sort_values('Time')
            # load the waveform and create the dataframe for storing predictions
            sig, fs = librosa.load(fn, sr=FS)
            out = pd.DataFrame({'time':np.arange(nfft/fs/2, (len(sig) - nfft/2)/fs, dt / downsample)})
            # resample the annotation onto the prediction time steps, only within the annotated time range
            mask = ((out.time > annot.Time.min())&(out.time < annot.Time.max()))
            out.loc[mask, 'annot'] = mir_eval.melody.resample_melody_series(annot.Time, annot.Freq, annot.Freq>0, out[mask].time, verbose=False)[0]
        else:
            out = pd.read_csv(preds_path)
            for algo in algos: # drop a column if all values are None
                if algo in out.columns and out[algo].isna().all():
                    out.drop(algo, axis=1, inplace=True)

            # check if everything has already been computed, and if yes skip the file
            if pd.Series(algos).isin(out.columns).all() or (args.quick and pd.Series(quick_algos).isin(out.columns).all()):
                continue
            sig, fs = librosa.load(fn, sr=FS)

        # The algorithms are run as if the signal were sampled `downsample` times slower:
        # times are stretched and fs reduced here, each estimated f0 is multiplied back by
        # `downsample` below, and times are restored before saving.
        out.time *= downsample
        fs /= downsample

        if 'praat_f0' not in out.columns: # PRAAT
            sndpitches0 = parselmouth.Sound(sig, fs).to_pitch_ac(pitch_floor=27.5, pitch_ceiling=fs//2, voicing_threshold=0.20, time_step=dt)
            time, f0, confidence = sndpitches0.xs(), sndpitches0.selected_array['frequency'], sndpitches0.selected_array['strength']
            mask = ((out.time>=time[0])&(out.time<=time[-1]))
            out.loc[mask, 'praat_f0'], out.loc[mask, 'praat_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.loc[mask, 'time'])
            out.praat_f0 *= downsample

        if 'pyin_f0' not in out.columns: # PYIN
            f0, voiced, prob = librosa.pyin(sig, sr=fs, fmin=27.5, fmax=fs//2, frame_length=nfft, hop_length=int(fs*dt), center=False)
            # pyin frames are assigned row by row, truncated to the number of rows of out
            out['pyin_f0'], out['pyin_conf'] = f0[:len(out)], prob[:len(out)]
            out.pyin_f0 *= downsample

        if 'tcrepe_f0' not in out.columns: # torch crepe out-of-the-box
            time, f0, confidence = run_tcrepe(tcrepe_model, sig, fs, dt)
            mask = ((out.time > time[0])&(out.time < time[-1]))
            out.loc[mask, 'tcrepe_f0'], out.loc[mask, 'tcrepe_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
            out.tcrepe_f0 *= downsample

        if 'tcrepe_ftsp_f0' not in out.columns and tcrepe_ftsp_model is not None: # torch crepe finetuned on the target species
            time, f0, confidence = run_tcrepe(tcrepe_ftsp_model, sig, fs, dt)
            mask = ((out.time > time[0])&(out.time < time[-1]))
            out.loc[mask, 'tcrepe_ftsp_f0'], out.loc[mask, 'tcrepe_ftsp_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
            out.tcrepe_ftsp_f0 *= downsample

        if 'tcrepe_ftoth_f0' not in out.columns and tcrepe_ftoth_model is not None: # torch crepe finetuned on other species than the target
            time, f0, confidence = run_tcrepe(tcrepe_ftoth_model, sig, fs, dt)
            mask = ((out.time > time[0])&(out.time < time[-1]))
            out.loc[mask, 'tcrepe_ftoth_f0'], out.loc[mask, 'tcrepe_ftoth_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
            out.tcrepe_ftoth_f0 *= downsample

        if not args.quick and 'crepe_f0' not in out.columns: # CREPE out-of-the-box tensorflow
            time, f0, confidence, activation = crepe.predict(sig, fs, step_size=int(dt*1e3), center=False, verbose=0) # step_size in ms
            out['crepe_f0'], out['crepe_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.time)
            out.crepe_f0 *= downsample

        if 'basic_f0' not in out.columns: # basic_pitch
            # basic_pitch reads the file itself; the pitch is taken as the strongest bin of its 'contour' output
            S = basic_pitch.inference.run_inference(fn, basic_pitch_model)['contour']
            time = np.arange(len(S)) * basic_pitch.constants.FFT_HOP / basic_pitch.constants.AUDIO_SAMPLE_RATE
            f0 = basic_pitch.constants.FREQ_BINS_CONTOURS[np.argmax(S, axis=1)]
            confidence = np.max(S, axis=1)
            out['basic_f0'], out['basic_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.time)
            out.basic_f0 *= downsample

        if 'pesto_f0' not in out.columns: # pesto out-of-the-box
            try:
                time, f0, confidence, activation = pesto.predict(torch.tensor(sig).unsqueeze(0), fs, step_size=int(dt*1e3), convert_to_freq=True) # step_size in ms
                out['pesto_f0'], out['pesto_conf'] = mir_eval.melody.resample_melody_series(time/1000, f0[0], confidence.numpy(), out.time, verbose=False)
                out.pesto_f0 *= downsample
            except Exception as err:
                # best effort: report the failure and store empty columns (dropped and retried on a later run),
                # while letting KeyboardInterrupt / SystemExit propagate
                tqdm.write(f'pesto failed on {fn}: {err!r}')
                out['pesto_f0'], out['pesto_conf'] = None, None

        if 'pesto_ft_f0' not in out.columns and os.path.isfile(f'pesto_ft/{specie}.pth'): # pesto finetuned
            try:
                time, f0, confidence, activation = pesto.predict(torch.tensor(sig).unsqueeze(0), fs, model_name=f'pesto_ft/{specie}.pth', step_size=int(dt*1e3), convert_to_freq=True) # step_size in ms
                out['pesto_ft_f0'], out['pesto_ft_conf'] = mir_eval.melody.resample_melody_series(time/1000 + 1e-6, f0[0], confidence.numpy(), out.time, verbose=False)
                out.pesto_ft_f0 *= downsample
            except Exception as err:
                # same best-effort handling as for the out-of-the-box pesto model
                tqdm.write(f'finetuned pesto failed on {fn}: {err!r}')
                out['pesto_ft_f0'], out['pesto_ft_conf'] = None, None

        # restore the original time axis before saving
        out.time /= downsample
        out.to_csv(preds_path, index=False)