diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index 2b18520d395e79ce5e4ab74b1b6c6b1e024f309a..87e7e5f24f117419e220e481a3e9b4a907234e3f 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -21,9 +21,10 @@ args = parser.parse_args()
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+frontend.to(device)
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
-model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
+model = torch.nn.Sequential(encoder, decoder).to(device).eval()
 model.load_state_dict(torch.load(args.modelname))
 
 df = pd.read_csv(args.detections)
@@ -33,7 +34,7 @@ loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, a
 with torch.inference_mode():
     encodings, idxs = [], []
     for x, idx in tqdm(loader):
-        encoding = model[:2](x.to(device))
+        encoding = model[0](frontend(x.to(device)))
         idxs.extend(idx)
         encodings.extend(encoding.cpu().detach())
 idxs = np.array(idxs)
diff --git a/new_specie/train_AE.py b/new_specie/train_AE.py
index befa82419159f18cf37fc127d7e4ffd908073d0b..e762d998dc6717047fb3de17d886fe1ec989a5ee 100755
--- a/new_specie/train_AE.py
+++ b/new_specie/train_AE.py
@@ -27,7 +27,7 @@ assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a mu
 frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
-model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
+model = torch.nn.Sequential(encoder, decoder).to(device)
 
 # training / optimisation setup
 lr, wdL2, batch_size = 0.003, 0.0, 64 if torch.cuda.is_available() else 16