diff --git a/frontend.py b/frontend.py index aa50b4b4dc7681443cd231df6221787d05a3466b..6465b1d30f8fb0fa34435f5a7c8b27ddeaaa612a 100644 --- a/frontend.py +++ b/frontend.py @@ -172,5 +172,5 @@ class STFT(torch.nn.Module): if not self.complex: x = x.norm(p=2, dim=-1) # restore original batchsize and channels in case we mashed them - x = x.reshape((batchsize, channels, -1) + x.shape[2:]) #if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:]) + x = x.reshape((batchsize, channels, -1) + x.shape[2:]) if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:]) return x diff --git a/models.py b/models.py index b046cafaa3bca3f1592ebe06fc39d17894e36fbf..a4291db2ec3d62a2a797fa80db879f93f7d07883 100644 --- a/models.py +++ b/models.py @@ -21,19 +21,20 @@ class Dropout1d(nn.Module): x = self.dropout(x) return x.squeeze(-1) -PHYSETER_NFEAT = 128 +PHYSETER_NFEAT = 32 PHYSETER_KERNEL = 7 -BALAENOPTERA_NFEAT = 128 +BALAENOPTERA_NFEAT = 32 BALAENOPTERA_KERNEL = 5 get = { 'physeter': { - 'weights': 'stft_depthwise_ovs_128_k7_r1.stdc', - 'fs': 50000, + 'weights':'stft_depthwise_ovs_64kHz_specBN_int16_newAnnot_randChan_32_k7_rBOMBYX2_prod.stdc', + 'fs': 64000, 'archi': nn.Sequential( STFT(512, 256), - MelFilter(50000, 512, 64, 2000, 25000), + MelFilter(64000, 512, 64, 2000, 25000), Log1p(trainable=True), + nn.BatchNorm1d(64), depthwise_separable_conv1d(64, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2), nn.BatchNorm1d(PHYSETER_NFEAT), nn.LeakyReLU(), @@ -43,20 +44,21 @@ get = { nn.LeakyReLU(), Dropout1d(), depthwise_separable_conv1d(PHYSETER_NFEAT, 1, PHYSETER_KERNEL, stride=2) - ), + ) }, 'balaenoptera': { - 'weights': 'dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc', - 'fs': 200, + 'weights': 'dw_m64_brown_4kHz2_int16_32_k5_r6_specBN.stdc', + 'fs': 4000, 'archi': nn.Sequential( - STFT(256, 32), - MelFilter(200, 256, 128, 0, 100), + STFT(4096, 256), + MelFilter(4000, 4096, 64, 0, 100), Log1p(trainable=True), - depthwise_separable_conv1d(128, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), + nn.BatchNorm1d(64), + depthwise_separable_conv1d(64, BALAENOPTERA_NFEAT, BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), nn.BatchNorm1d(BALAENOPTERA_NFEAT), nn.LeakyReLU(), Dropout1d(), - depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), + depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), nn.BatchNorm1d(BALAENOPTERA_NFEAT), nn.LeakyReLU(), Dropout1d(), diff --git a/run_CNN.py b/run_CNN.py index e15d1df2ecd14cb2521f8693f659165caf8e506c..07daf2bd919eeaf1531497b1b00bab4b9cd8e9f4 100644 --- a/run_CNN.py +++ b/run_CNN.py @@ -74,7 +74,7 @@ out = pd.DataFrame(columns=['filename', 'offset', 'prediction']) fns, offsets, preds = [], [], [] # forward the model on each batch -with torch.no_grad(): +with torch.inference_mode(): for x, meta in tqdm(loader, desc='Model inference'): x = x.to(device) pred = special.expit(model(x).cpu().detach().numpy()) diff --git a/weights/dw_m64_brown_4kHz2_int16_32_k5_r6_specBN.stdc b/weights/dw_m64_brown_4kHz2_int16_32_k5_r6_specBN.stdc new file mode 100644 index 0000000000000000000000000000000000000000..30f0504e6d652efe657238c4013349fb0d984888 Binary files /dev/null and b/weights/dw_m64_brown_4kHz2_int16_32_k5_r6_specBN.stdc differ diff --git a/weights/stft_depthwise_ovs_64kHz_specBN_int16_newAnnot_randChan_32_k7_rBOMBYX2_prod.stdc b/weights/stft_depthwise_ovs_64kHz_specBN_int16_newAnnot_randChan_32_k7_rBOMBYX2_prod.stdc new file mode 100644 index 0000000000000000000000000000000000000000..f2f0cfc38ac186661671d01f916496b52532ff8b Binary files /dev/null and b/weights/stft_depthwise_ovs_64kHz_specBN_int16_newAnnot_randChan_32_k7_rBOMBYX2_prod.stdc differ