import argparse
import os
import sys

import numpy as np
import pandas as pd
import soundfile as sf
import torch
from scipy import signal
from torch.utils import data
from tqdm import tqdm

import models

# Command-line interface. Positional args select the input folder, the target
# species (which picks the model config below) and the output pickle path.
parser = argparse.ArgumentParser(description="Run this script to use a CNN for inference on a folder of audio files.")
parser.add_argument('audio_folder', type=str, help='Path of the folder with audio files to process')
parser.add_argument('specie', type=str, help='Target specie to detect', choices=['megaptera', 'delphinid', 'orcinus', 'physeter', 'balaenoptera'])
parser.add_argument('pred_fn', type=str, help='Filename for the output table containing model predictions')
# NOTE: the original lines ended with stray commas, silently building and
# discarding 1-tuples around each add_argument call — removed.
parser.add_argument('-lensample', type=float, help='Length of the signal excerpts to process (sec)', default=5)
parser.add_argument('-batch_size', type=int, help='Amount of samples to process at a time', default=32)
parser.add_argument('-maxPool', help='Whether to keep only the maximal prediction of a sample or the full sequence', action='store_true')
parser.add_argument('-no-maxPool', dest='maxPool', action='store_false')
parser.set_defaults(maxPool=True)  # maxPool defaults to True unless -no-maxPool is given
args = parser.parse_args()

# Per-species model configuration: checkpoint filename ('stdc') and the sample
# rate (Hz) the network expects its input at.
# NOTE(review): 'orcinus' maps to an empty string instead of a config dict —
# selecting it will fail as soon as meta_model['stdc'] is looked up; presumably
# no weights have been released for that species yet.
_MODEL_CONFIG = {
    'delphinid': {
        'stdc': 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc',
        'fs': 96000
    },
    'megaptera': {
        'stdc': 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc',
        'fs': 11025
    },
    'orcinus': '',
    'physeter': {
        'stdc': 'stft_depthwise_ovs_128_k7_r1.stdc',
        'fs': 50000
    },
    'balaenoptera': {
        'stdc': 'dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc',
        'fs': 200
    }
}
meta_model = _MODEL_CONFIG[args.specie]


def collate_fn(batch):
    """Collate a batch while dropping failed samples.

    Dataset.__getitem__ returns None for unreadable excerpts; those entries
    are filtered out before collating. Returns None if nothing survives.
    """
    kept = [item for item in batch if item is not None]
    if not kept:
        return None
    return data.dataloader.default_collate(kept)

norm = lambda arr: (arr - np.mean(arr) ) / np.std(arr)

class Dataset(data.Dataset):
    """Torch dataset yielding fixed-length, normalised mono excerpts.

    Scans *folder* for readable audio files and enumerates non-overlapping
    excerpts of *lensample* seconds; each item is (float32 tensor, metadata
    dict with 'fn' and 'offset'), or None when a read fails (filtered out
    by collate_fn).

    NOTE(review): paths are built as folder+fn, so *folder* must end with a
    path separator — matches how the script has always been invoked.
    """

    def __init__(self, folder, fs, lensample):
        # Bug fix: the original `super(Dataset, self)` built the proxy but
        # never called __init__ on it — an expression with no effect.
        super().__init__()
        print('initializing dataset...')
        self.samples = []
        for fn in os.listdir(folder):
            try:
                duration = sf.info(folder+fn).duration
            except Exception:  # narrowed from bare except: keep Ctrl-C working
                print(f'Skipping {fn} (unable to read as audio)')
                continue
            # Non-overlapping offsets; the +.01 tolerates float rounding so a
            # file of exactly N*lensample seconds still yields its last excerpt.
            self.samples.extend([{'fn': fn, 'offset': offset}
                                 for offset in np.arange(0, duration+.01-lensample, lensample)])
        self.fs, self.folder, self.lensample = fs, folder, lensample

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # Re-read the file's native rate: files in the folder may differ from
        # the model rate self.fs, and offsets are expressed in seconds.
        fs = sf.info(self.folder+sample['fn']).samplerate
        try:
            sig, fs = sf.read(self.folder+sample['fn'],
                              start=int(sample['offset']*fs),
                              stop=int((sample['offset']+self.lensample)*fs),
                              always_2d=True)
        except Exception:  # narrowed from bare except
            print('Failed loading '+sample['fn'])
            return None  # dropped by collate_fn
        sig = sig[:, 0]  # keep first channel only
        if fs != self.fs:
            # Bug fix: lensample is a float (argparse type=float), so the
            # target sample count must be cast to int for scipy resample.
            sig = signal.resample(sig, int(self.lensample*self.fs))
        sig = norm(sig)
        return torch.tensor(sig).float(), sample


# prepare model
# Pick the device first so the checkpoint can be mapped onto it directly:
# without map_location, a CUDA-saved checkpoint fails to load on a CPU-only
# machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = models.get[args.specie]
model.load_state_dict(torch.load(f"weights/{meta_model['stdc']}", map_location=device))
model.eval()
model.to(device)

# prepare data loader and output storage for predictions
loader = data.DataLoader(Dataset(args.audio_folder, meta_model['fs'], args.lensample),
                         batch_size=args.batch_size, collate_fn=collate_fn,
                         num_workers=8, prefetch_factor=4)
out = pd.DataFrame(columns=['filename', 'offset', 'prediction'])
fns, offsets, preds = [], [], []
if len(loader) == 0:
    print('Unable to open any audio file in the given folder')
    sys.exit(1)  # error path: exit non-zero (the original exit() returned 0)

# run inference batch by batch
with torch.no_grad():
    for x, meta in tqdm(loader):
        x = x.to(device)
        pred = model(x).detach().cpu().numpy()
        if args.maxPool:
            # keep only the strongest frame-level prediction per excerpt
            pred = pred.max(axis=-1).reshape(len(x))
        else:
            # keep the full prediction sequence, one row per excerpt
            pred = pred.reshape(len(x), -1)
        preds.extend(pred)
        fns.extend(meta['fn'])
        offsets.extend(meta['offset'].numpy())

# Build the output table in one shot from the accumulated columns instead of
# assigning columns via attribute access on a pre-declared empty frame —
# pandas discourages attribute-style column assignment and direct
# construction guarantees consistent lengths and index.
out = pd.DataFrame({'filename': fns, 'offset': offsets, 'prediction': preds})
out.to_pickle(args.pred_fn)