import argparse
import os

import numpy as np
import pandas as pd
import torch
import umap
from tqdm import tqdm

import models
import utils as u

torch.multiprocessing.set_sharing_strategy('file_system')

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    description="Compute the auto-encoder embeddings of vocalizations once the model has been trained with train_AE.py")
parser.add_argument('modelname', type=str, help='Filename of the AE weights (.stdc or .weights)')
parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the sound file) and pos (center of the detection in seconds) are required")
parser.add_argument('-audio_folder', type=str, default='./', help="Folder from which to load sound files")
parser.add_argument('-NFFT', type=int, default=1024, help="FFT size for the spectrogram computation")
parser.add_argument('-nMel', type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
parser.add_argument('-bottleneck', type=int, default=16, help="Size of the auto-encoder's bottleneck")
parser.add_argument('-SR', type=int, default=44100, help="Sample rate of the signal extracts before spectrogram computation")
parser.add_argument('-sampleDur', type=float, default=1, help="Duration (in seconds) of the signal extracts surrounding detections to be encoded")
parser.add_argument('-medfilt', action='store_true', help="Apply a frequency-wise median filter (a larger sampleDur is used internally only for a better median estimation)")
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Spectrogram frontend (optionally with frequency-wise median filtering) and auto-encoder.
frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt \
    else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
frontend.to(device)
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel // 32 * 4), (args.nMel // 32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel // 32, 4))
# The full encoder/decoder is assembled only to load the checkpoint; embeddings use the encoder alone.
model = torch.nn.Sequential(encoder, decoder).to(device).eval()
model.load_state_dict(torch.load(args.modelname, map_location=device))

df = pd.read_csv(args.detections)

print('Computing AE projections...')
# A wider extract (+2 s) is loaded when median filtering, only for a better median estimation.
loader = torch.utils.data.DataLoader(
    u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)),
    batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
with torch.inference_mode():
    encodings, idxs = [], []
    for x, idx in tqdm(loader):
        encoding = encoder(frontend(x.to(device)))
        idxs.extend(idx)
        encodings.extend(encoding.cpu())
idxs = np.array(idxs)
encodings = np.stack(encodings)

print('Computing UMAP projections...')
X = umap.UMAP(n_jobs=-1).fit_transform(encodings)

out_fn = f'encodings_{os.path.basename(args.detections).rsplit(".", 1)[0]}.npy'
print(f'Saving into {out_fn}')
np.save(out_fn, {'encodings': encodings, 'idx': idxs, 'umap': X})
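
# Example usage -- a minimal sketch; the script and file names below are hypothetical,
# and the weights file is assumed to come from train_AE.py:
#   python compute_embeddings.py myAE_weights.stdc detections.csv -audio_folder ./audio -bottleneck 16
#
# The output .npy stores a pickled dict, so it must be read back with allow_pickle, e.g.:
#   data = np.load('encodings_detections.npy', allow_pickle=True).item()
#   encodings, idxs, umap_2d = data['encodings'], data['idx'], data['umap']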