diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index 696b73550e808e2baf1da6c6f7b8064e08c1abbb..93f344d5b9c1bcacd522451e75083b5c24e0c568 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -15,10 +15,12 @@ parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands fo
 parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
 parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
+parser.add_argument('-medfilt', action='store_true', help="Apply a frequency-wise median filter to the spectrogram (a longer sampleDur is loaded internally, used only to improve the median estimate)")
+parser.set_defaults(medfilt=False)
 args = parser.parse_args()
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+frontend = (models.frontend_medfilt if args.medfilt else models.frontend)(args.SR, args.NFFT, args.sampleDur, args.nMel)
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
@@ -26,7 +28,7 @@ model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
 df = pd.read_csv(args.detections)
 
 print('Computing AE projections...')
-loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur, channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8)
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0), channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
 with torch.inference_mode():
     encodings, idxs = [], []
     for x, idx in tqdm(loader):
@@ -38,4 +40,4 @@ encodings = np.stack(encodings)
 
 print('Computing UMAP projections...')
 X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
-np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split('.')[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
+np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split(".")[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
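
Note on the helpers referenced above (not part of the patch): neither models.frontend_medfilt nor u.collate_fn appears in this diff. Since the loader now fetches sampleDur + 2 seconds when -medfilt is set, the extra context can only serve to stabilise the per-band median before cropping back to the requested window. The following is a hypothetical sketch of that behaviour, not the repository's actual code; the class structure, hop length of NFFT/2, log scaling, and centred cropping are all assumptions.

# Hypothetical sketch only -- the real models.frontend_medfilt is not shown in this diff.
import torch
import torchaudio

class frontend_medfilt(torch.nn.Module):
    """Mel spectrogram with frequency-wise median subtraction.

    Assumes the loader passes sampleDur + 2 seconds of audio; the extra
    context is used only to improve the median estimate, and the output is
    cropped back to the central sampleDur window.
    """
    def __init__(self, sr, nfft, sampleDur, n_mel):
        super().__init__()
        self.melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=nfft, hop_length=nfft // 2, n_mels=n_mel)
        self.n_frames = int(sampleDur * sr / (nfft // 2))  # frames kept after cropping

    def forward(self, x):                                  # x: (batch, samples)
        spec = torch.log10(self.melspec(x) + 1e-10)
        med = spec.median(dim=-1, keepdim=True).values     # frequency-wise median over time
        spec = spec - med
        start = (spec.shape[-1] - self.n_frames) // 2      # keep the central window
        return spec[..., start:start + self.n_frames]

Likewise an assumption: u.collate_fn presumably just drops samples the Dataset failed to load before batching, along these lines.

def collate_fn(batch):
    # filter out failed reads (returned as None) before collating
    batch = [b for b in batch if b is not None]
    return torch.utils.data.default_collate(batch)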