import argparse
import os

import numpy as np
import pandas as pd
import torch
import umap
from tqdm import tqdm

import models
import utils as u

torch.multiprocessing.set_sharing_strategy('file_system')

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Compute the auto-encoder embeddings of vocalizations once the model has been trained with train_AE.py")
parser.add_argument('modelname', type=str, help='Filename of the AE weights (.stdc or .weights)')
parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the sound file) and pos (center of the detection in seconds) are required")
parser.add_argument('-audio_folder', type=str, default='./', help="Folder from which to load sound files")
parser.add_argument('-NFFT', type=int, default=1024, help="FFT size for the spectrogram computation")
parser.add_argument('-nMel', type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
parser.add_argument('-bottleneck', type=int, default=16, help="Size of the auto-encoder's bottleneck")
parser.add_argument('-SR', type=int, default=44100, help="Sample rate of the signal extracts before spectrogram computation")
parser.add_argument('-sampleDur', type=float, default=1, help="Duration (in seconds) of the signal extracts surrounding detections to be encoded")
parser.add_argument('-medfilt', action='store_true', help="Apply a frequency-wise median filter (a larger sampleDur is used internally only for a better median estimation)")
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Spectrogram frontend (optionally with frequency-wise median filtering) and auto-encoder.
frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt \
    else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
frontend.to(device)
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel // 32 * 4), (args.nMel // 32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel // 32, 4))
# The full encoder/decoder is assembled only to load the checkpoint; embeddings use the encoder alone.
model = torch.nn.Sequential(encoder, decoder).to(device).eval()
model.load_state_dict(torch.load(args.modelname, map_location=device))

df = pd.read_csv(args.detections)

print('Computing AE projections...')
# A wider extract (+2 s) is loaded when median filtering, only for a better median estimation.
loader = torch.utils.data.DataLoader(
    u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)),
    batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
with torch.inference_mode():
    encodings, idxs = [], []
    for x, idx in tqdm(loader):
        encoding = encoder(frontend(x.to(device)))
        idxs.extend(idx)
        encodings.extend(encoding.cpu())
idxs = np.array(idxs)
encodings = np.stack(encodings)

print('Computing UMAP projections...')
X = umap.UMAP(n_jobs=-1).fit_transform(encodings)

out_fn = f'encodings_{os.path.basename(args.detections).rsplit(".", 1)[0]}.npy'
print(f'Saving into {out_fn}')
np.save(out_fn, {'encodings': encodings, 'idx': idxs, 'umap': X})
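
# Example usage -- a minimal sketch; the script and file names below are hypothetical,
# and the weights file is assumed to come from train_AE.py:
#   python compute_embeddings.py myAE_weights.stdc detections.csv -audio_folder ./audio -bottleneck 16
#
# The output .npy stores a pickled dict, so it must be read back with allow_pickle, e.g.:
#   data = np.load('encodings_detections.npy', allow_pickle=True).item()
#   encodings, idxs, umap_2d = data['encodings'], data['idx'], data['umap']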