diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index 696b73550e808e2baf1da6c6f7b8064e08c1abbb..93f344d5b9c1bcacd522451e75083b5c24e0c568 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -15,10 +15,12 @@ parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands fo
 parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
 parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
+parser.add_argument('-medfilt', action='store_true', help="Apply a frequency-wise median filter (a larger sampleDur is used internally, only for a better median estimate)")
+parser.set_defaults(medfilt=False)
 args = parser.parse_args()
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
@@ -26,7 +28,7 @@ model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
 df = pd.read_csv(args.detections)
 
 print('Computing AE projections...')
-loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur, channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8)
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0), channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
 with torch.inference_mode():
     encodings, idxs = [], []
     for x, idx in tqdm(loader):
@@ -38,4 +40,4 @@ encodings = np.stack(encodings)
 
 print('Computing UMAP projections...')
 X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
-np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split('.')[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
+np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split(".")[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
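
Note for reviewers: neither `models.frontend_medfilt` nor `u.collate_fn` is shown in this diff, so the two sketches below are only assumptions about their behavior, inferred from the `-medfilt` help text and from the `sampleDur + 2` padding added to the `Dataset`. The class name, signature, and the 2 s pad are illustrative, not the actual implementations.

A frequency-wise median filter of this kind is typically implemented by estimating the median of each Mel band over the whole padded clip, subtracting it, and cropping the spectrogram back to the central `sampleDur` seconds, roughly:

```python
import torch
import torchaudio

class FrontendMedfilt(torch.nn.Module):
    """Hypothetical stand-in for models.frontend_medfilt.

    Expects clips of sample_dur + pad_dur seconds (the DataLoader above feeds
    sampleDur + 2 when -medfilt is set); the extra context is used only to get
    a better per-band median estimate before cropping.
    """
    def __init__(self, sr, nfft, sample_dur, n_mel, pad_dur=2):
        super().__init__()
        self.melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=nfft, hop_length=nfft // 2, n_mels=n_mel)
        self.sr, self.nfft = sr, nfft
        self.sample_dur, self.pad_dur = sample_dur, pad_dur

    def forward(self, x):                                   # x: (batch, samples)
        spec = torch.log10(self.melspec(x) + 1e-10)         # (batch, n_mel, frames)
        spec = spec - spec.median(-1, keepdim=True).values  # frequency-wise median filter
        frames_per_sec = self.sr / (self.nfft // 2)         # frame rate at hop = nfft/2
        start = int(self.pad_dur / 2 * frames_per_sec)      # skip the leading pad
        stop = start + int(self.sample_dur * frames_per_sec)
        return spec[..., start:stop].unsqueeze(1)           # (batch, 1, n_mel, frames)
```

Similarly, `u.collate_fn` looks like the usual "drop failed reads" collate, which becomes necessary if `Dataset` returns `None` for detections that sit too close to a file edge to carry the extra 2 s of context:

```python
from torch.utils.data import default_collate

def collate_fn(batch):
    # Filter out samples the Dataset failed to load (returned as None),
    # then fall back to PyTorch's default batching.
    batch = [b for b in batch if b is not None]
    return default_collate(batch)
```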