Skip to content
Snippets Groups Projects
Commit bf7af8ca authored by Paul Best's avatar Paul Best
Browse files

fix archi (dont include frontend)

parent 8f733869
No related branches found
No related tags found
No related merge requests found
......@@ -21,9 +21,10 @@ args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
frontend.to(device)
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
model = torch.nn.Sequential(encoder, decoder).to(device).eval()
model.load_state_dict(torch.load(args.modelname))
df = pd.read_csv(args.detections)
......@@ -33,7 +34,7 @@ loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, a
with torch.inference_mode():
encodings, idxs = [], []
for x, idx in tqdm(loader):
encoding = model[:2](x.to(device))
encoding = model[0](frontend(x.to(device)))
idxs.extend(idx)
encodings.extend(encoding.cpu().detach())
idxs = np.array(idxs)
......
......@@ -27,7 +27,7 @@ assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a mu
frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
model = torch.nn.Sequential(encoder, decoder).to(device)
# training / optimisation setup
lr, wdL2, batch_size = 0.003, 0.0, 64 if torch.cuda.is_available() else 16
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment