From bf7af8ca818820b4e271416110adf1a68ca24632 Mon Sep 17 00:00:00 2001
From: lamipaul <paulobest25@gmail.com>
Date: Mon, 10 Jul 2023 11:09:57 +0200
Subject: [PATCH] fix archi (dont include frontend)

---
 new_specie/compute_embeddings.py | 5 +++--
 new_specie/train_AE.py | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index 2b18520..87e7e5f 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -21,9 +21,10 @@ args = parser.parse_args()
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+frontend.to(device)
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
-model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
+model = torch.nn.Sequential(encoder, decoder).to(device).eval()
 model.load_state_dict(torch.load(args.modelname))
 
 df = pd.read_csv(args.detections)
 
@@ -33,7 +34,7 @@ loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, a
 with torch.inference_mode():
     encodings, idxs = [], []
     for x, idx in tqdm(loader):
-        encoding = model[:2](x.to(device))
+        encoding = model[0](frontend(x.to(device)))
         idxs.extend(idx)
         encodings.extend(encoding.cpu().detach())
     idxs = np.array(idxs)
diff --git a/new_specie/train_AE.py b/new_specie/train_AE.py
index befa824..e762d99 100755
--- a/new_specie/train_AE.py
+++ b/new_specie/train_AE.py
@@ -27,7 +27,7 @@ assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a mu
 frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
-model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
+model = torch.nn.Sequential(encoder, decoder).to(device)
 
 # training / optimisation setup
 lr, wdL2, batch_size = 0.003, 0.0, 64 if torch.cuda.is_available() else 16
-- 
GitLab