Skip to content
Snippets Groups Projects
Commit 62192dcc authored by Paul Best's avatar Paul Best
Browse files

match weights and archi

parent 7807ef5f
No related branches found
No related tags found
No related merge requests found
...@@ -172,5 +172,5 @@ class STFT(torch.nn.Module): ...@@ -172,5 +172,5 @@ class STFT(torch.nn.Module):
if not self.complex: if not self.complex:
x = x.norm(p=2, dim=-1) x = x.norm(p=2, dim=-1)
# restore original batchsize and channels in case we mashed them # restore original batchsize and channels in case we mashed them
x = x.reshape((batchsize, channels, -1) + x.shape[2:]) #if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:]) x = x.reshape((batchsize, channels, -1) + x.shape[2:]) if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:])
return x return x
...@@ -21,19 +21,20 @@ class Dropout1d(nn.Module): ...@@ -21,19 +21,20 @@ class Dropout1d(nn.Module):
x = self.dropout(x) x = self.dropout(x)
return x.squeeze(-1) return x.squeeze(-1)
PHYSETER_NFEAT = 128 PHYSETER_NFEAT = 32
PHYSETER_KERNEL = 7 PHYSETER_KERNEL = 7
BALAENOPTERA_NFEAT = 128 BALAENOPTERA_NFEAT = 32
BALAENOPTERA_KERNEL = 5 BALAENOPTERA_KERNEL = 5
get = { get = {
'physeter': { 'physeter': {
'weights': 'stft_depthwise_ovs_128_k7_r1.stdc', 'weights':'stft_depthwise_ovs_64kHz_specBN_int16_newAnnot_randChan_32_k7_rBOMBYX2_prod.stdc',
'fs': 50000, 'fs': 64000,
'archi': nn.Sequential( 'archi': nn.Sequential(
STFT(512, 256), STFT(512, 256),
MelFilter(50000, 512, 64, 2000, 25000), MelFilter(64000, 512, 64, 2000, 25000),
Log1p(trainable=True), Log1p(trainable=True),
nn.BatchNorm1d(64),
depthwise_separable_conv1d(64, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2), depthwise_separable_conv1d(64, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2),
nn.BatchNorm1d(PHYSETER_NFEAT), nn.BatchNorm1d(PHYSETER_NFEAT),
nn.LeakyReLU(), nn.LeakyReLU(),
...@@ -43,20 +44,21 @@ get = { ...@@ -43,20 +44,21 @@ get = {
nn.LeakyReLU(), nn.LeakyReLU(),
Dropout1d(), Dropout1d(),
depthwise_separable_conv1d(PHYSETER_NFEAT, 1, PHYSETER_KERNEL, stride=2) depthwise_separable_conv1d(PHYSETER_NFEAT, 1, PHYSETER_KERNEL, stride=2)
), )
}, },
'balaenoptera': { 'balaenoptera': {
'weights': 'dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc', 'weights': 'dw_m64_brown_4kHz2_int16_32_k5_r6_specBN.stdc',
'fs': 200, 'fs': 4000,
'archi': nn.Sequential( 'archi': nn.Sequential(
STFT(256, 32), STFT(4096, 256),
MelFilter(200, 256, 128, 0, 100), MelFilter(4000, 4096, 64, 0, 100),
Log1p(trainable=True), Log1p(trainable=True),
depthwise_separable_conv1d(128, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), nn.BatchNorm1d(64),
depthwise_separable_conv1d(64, BALAENOPTERA_NFEAT, BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2),
nn.BatchNorm1d(BALAENOPTERA_NFEAT), nn.BatchNorm1d(BALAENOPTERA_NFEAT),
nn.LeakyReLU(), nn.LeakyReLU(),
Dropout1d(), Dropout1d(),
depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2),
nn.BatchNorm1d(BALAENOPTERA_NFEAT), nn.BatchNorm1d(BALAENOPTERA_NFEAT),
nn.LeakyReLU(), nn.LeakyReLU(),
Dropout1d(), Dropout1d(),
......
...@@ -74,7 +74,7 @@ out = pd.DataFrame(columns=['filename', 'offset', 'prediction']) ...@@ -74,7 +74,7 @@ out = pd.DataFrame(columns=['filename', 'offset', 'prediction'])
fns, offsets, preds = [], [], [] fns, offsets, preds = [], [], []
# forward the model on each batch # forward the model on each batch
with torch.no_grad(): with torch.inference_mode():
for x, meta in tqdm(loader, desc='Model inference'): for x, meta in tqdm(loader, desc='Model inference'):
x = x.to(device) x = x.to(device)
pred = special.expit(model(x).cpu().detach().numpy()) pred = special.expit(model(x).cpu().detach().numpy())
......
File added
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment