diff --git a/new_specie/models.py b/new_specie/models.py
index f7dcc80573136e04eed5839b15caf38a11c030c7..d5d0bb85352d1682ab8fcd32e2c112ff18bbc70b 100755
--- a/new_specie/models.py
+++ b/new_specie/models.py
@@ -1,7 +1,7 @@
 import torchvision.models as torchmodels
 from torch import nn
 import utils as u
-from filterbank import STFT, MelFilter, Log1p
+from filterbank import STFT, MelFilter, Log1p, MedFilt
 
 vgg16 = torchmodels.vgg16(weights=torchmodels.VGG16_Weights.DEFAULT)
 vgg16 = vgg16.features[:13]
@@ -10,6 +10,15 @@ for nm, mod in vgg16.named_modules():
         setattr(vgg16, nm,  nn.AvgPool2d(2 ,2))
 
 
+frontend_medfilt = lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
+  STFT(nfft, int((sampleDur*sr - nfft)/128)),
+  MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
+  Log1p(7, trainable=False),
+  nn.InstanceNorm2d(1),
+  MedFilt(),
+  u.Croper2D(n_mel, 128)
+)
+
 frontend = lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
   STFT(nfft, int((sampleDur*sr - nfft)/128)),
   MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
diff --git a/new_specie/train_AE.py b/new_specie/train_AE.py
index 15f63ebe99eb00e0d2a958b21476a1c60e6a673d..34100fa7f138c4f93a9fdbbd5c356c9fcec8acee 100755
--- a/new_specie/train_AE.py
+++ b/new_specie/train_AE.py
@@ -15,6 +15,8 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec
 parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
 parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
+parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
+parser.set_defaults(feature=False)
 parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
 args = parser.parse_args()
 
@@ -22,7 +24,7 @@ args = parser.parse_args()
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 assert args.nMel % 32 == 0 and args.nMel > 0, "nMel argument should be a multiple of 32"
 assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a multiple of the last volume\'s size (nMel//32 * 4)"
-frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
@@ -36,7 +38,7 @@ loss_fun = torch.nn.MSELoss()
 
 # data loader
 df = pd.read_csv(args.detections)
-loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
 
 modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.stdc'
 step, writer = 0, SummaryWriter('runs/'+modelname)
diff --git a/new_specie/utils.py b/new_specie/utils.py
index 09acb635fbd5b8f5b07d63910e52731cf8d5cb38..9697d45954936796fef6ce48982c200a0c4d4944 100755
--- a/new_specie/utils.py
+++ b/new_specie/utils.py
@@ -56,4 +56,4 @@ class Croper2D(nn.Module):
         super(Croper2D, self).__init__()
         self.shape = shape
     def forward(self, x):
-        return x[:,:,:self.shape[0],:self.shape[1]]
+        return x[:,:,:self.shape[0],(x.shape[-1] - self.shape[1])//2:-(x.shape[-1] - self.shape[1])//2]