Commit bf7af8ca authored by Paul Best

fix archi (dont include frontend)

parent 8f733869
@@ -21,9 +21,10 @@ args = parser.parse_args()
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+frontend.to(device)
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
-model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
+model = torch.nn.Sequential(encoder, decoder).to(device).eval()
 model.load_state_dict(torch.load(args.modelname))
 df = pd.read_csv(args.detections)
@@ -33,7 +34,7 @@ loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, a
 with torch.inference_mode():
     encodings, idxs = [], []
     for x, idx in tqdm(loader):
-        encoding = model[:2](x.to(device))
+        encoding = model[0](frontend(x.to(device)))
         idxs.extend(idx)
         encodings.extend(encoding.cpu().detach())
     idxs = np.array(idxs)

@@ -27,7 +27,7 @@ assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a mu
 frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
-model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
+model = torch.nn.Sequential(encoder, decoder).to(device)
 # training / optimisation setup
 lr, wdL2, batch_size = 0.003, 0.0, 64 if torch.cuda.is_available() else 16
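
A minimal sketch of the practical effect of this change, using hypothetical stand-in modules (torch.nn.Identity / torch.nn.Linear rather than the repository's models.frontend, sparrow_encoder and sparrow_decoder): dropping the frontend from the torch.nn.Sequential shifts the child indices, so the keys of the saved state_dict change, which is presumably why the training script and the embedding script are updated in the same commit.

import torch

# Hypothetical stand-ins, only to illustrate how the Sequential layout
# affects state_dict keys; these are not the repository's real modules.
frontend = torch.nn.Identity()
encoder = torch.nn.Linear(8, 4)
decoder = torch.nn.Linear(4, 8)

old_model = torch.nn.Sequential(frontend, encoder, decoder)  # pre-commit layout
new_model = torch.nn.Sequential(encoder, decoder)            # post-commit layout

print(list(old_model.state_dict()))  # ['1.weight', '1.bias', '2.weight', '2.bias']
print(list(new_model.state_dict()))  # ['0.weight', '0.bias', '1.weight', '1.bias']

# With the frontend excluded, the encoder sits at index 0, so embeddings are
# computed as model[0](frontend(x)) instead of model[:2](x), as in the diff above.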