diff --git a/models.py b/models.py
index 75660662c698d0c32e6f7f09ea74a560f8caafa6..730b943f1e1d2aa0b8f0f3a9ffd7a0c23039d28f 100644
--- a/models.py
+++ b/models.py
@@ -8,7 +8,7 @@ class depthwise_separable_conv1d(nn.Module):
         self.depthwise = nn.Conv1d(nin, nin, kernel_size=kernel, padding=padding, stride=stride, groups=nin)
         self.pointwise = nn.Conv1d(nin, nout, kernel_size=1)
     def forward(self, x):
-        out = self.depthwise(x)
+        out = self.depthwise(x.squeeze(1))
         out = self.pointwise(out)
         return out
 
@@ -27,105 +27,162 @@ BALAENOPTERA_NFEAT = 128
 BALAENOPTERA_KERNEL = 5
 
 get = {
-    'physeter' : nn.Sequential(
-        STFT(512, 256),
-        MelFilter(50000, 512, 64, 2000, 25000),
-        Log1p(),
-        depthwise_separable_conv1d(64, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2),
-        nn.BatchNorm1d(PHYSETER_NFEAT),
-        nn.LeakyReLU(),
-        Dropout1d(),
-        depthwise_separable_conv1d(PHYSETER_NFEAT, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2),
-        nn.BatchNorm1d(PHYSETER_NFEAT),
-        nn.LeakyReLU(),
-        Dropout1d(),
-        depthwise_separable_conv1d(PHYSETER_NFEAT, 1, PHYSETER_KERNEL, stride=2)
-    ),
-    'balaenoptera': nn.Sequential(
-        STFT(256, 32),
-        MelFilter(200, 256, 128, 0, 100),
-        Log1p(),
-        depthwise_separable_conv1d(128, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2),
-        nn.BatchNorm1d(BALAENOPTERA_NFEAT),
-        nn.LeakyReLU(),
-        Dropout1d(),
-        depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2),
-        nn.BatchNorm1d(BALAENOPTERA_NFEAT),
-        nn.LeakyReLU(),
-        Dropout1d(),
-        depthwise_separable_conv1d(BALAENOPTERA_NFEAT, 1, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2)
-    ),
-    'megaptera' : nn.Sequential(
-        nn.Sequential(
-            STFT(512, 64),
-            MelFilter(11025, 512, 64, 100, 3000),
-            PCENLayer(64)
+    'physeter': {
+        'weights': 'stft_depthwise_ovs_128_k7_r1.stdc',
+        'fs': 50000,
+        'archi': nn.Sequential(
+            STFT(512, 256),
+            MelFilter(50000, 512, 64, 2000, 25000),
+            Log1p(trainable=True),
+            depthwise_separable_conv1d(64, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2),
+            nn.BatchNorm1d(PHYSETER_NFEAT),
+            nn.LeakyReLU(),
+            Dropout1d(),
+            depthwise_separable_conv1d(PHYSETER_NFEAT, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2),
+            nn.BatchNorm1d(PHYSETER_NFEAT),
+            nn.LeakyReLU(),
+            Dropout1d(),
+            depthwise_separable_conv1d(PHYSETER_NFEAT, 1, PHYSETER_KERNEL, stride=2)
         ),
-        nn.Sequential(
-            nn.Conv2d(1, 32, 3, bias=False),
-            nn.BatchNorm2d(32),
-            nn.LeakyReLU(0.01),
-            nn.Conv2d(32, 32, 3,bias=False),
-            nn.BatchNorm2d(32),
-            nn.MaxPool2d(3),
-            nn.LeakyReLU(0.01),
-            nn.Conv2d(32, 32, 3, bias=False),
-            nn.BatchNorm2d(32),
-            nn.LeakyReLU(0.01),
-            nn.Conv2d(32, 32, 3, bias=False),
-            nn.BatchNorm2d(32),
-            nn.LeakyReLU(0.01),
-            nn.Conv2d(32, 64, (16, 3), bias=False),
-            nn.BatchNorm2d(64),
-            nn.MaxPool2d((1,3)),
-            nn.LeakyReLU(0.01),
-            nn.Dropout(p=.5),
-            nn.Conv2d(64, 256, (1, 9), bias=False),
-            nn.BatchNorm2d(256),
-            nn.LeakyReLU(0.01),
-            nn.Dropout(p=.5),
-            nn.Conv2d(256, 64, 1, bias=False),
-            nn.BatchNorm2d(64),
-            nn.LeakyReLU(0.01),
-            nn.Dropout(p=.5),
-            nn.Conv2d(64, 1, 1, bias=False)
+    },
+    'balaenoptera': {
+        'weights': 'dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc',
+        'fs': 200,
+        'archi': nn.Sequential(
+            STFT(256, 32),
+            MelFilter(200, 256, 128, 0, 100),
+            Log1p(trainable=True),
+            depthwise_separable_conv1d(128, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2),
+            nn.BatchNorm1d(BALAENOPTERA_NFEAT),
+            nn.LeakyReLU(),
+            Dropout1d(),
+            depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2),
+            nn.BatchNorm1d(BALAENOPTERA_NFEAT),
+            nn.LeakyReLU(),
+            Dropout1d(),
+            depthwise_separable_conv1d(BALAENOPTERA_NFEAT, 1, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2)
         )
-    ),
-    'delphinid' : nn.Sequential(
-        nn.Sequential(
-            STFT(4096, 1024),
-            MelFilter(96000, 4096, 128, 3000, 30000),
-            PCENLayer(128)
-        ),
-        nn.Sequential(
-            nn.Conv2d(1, 32, 3, bias=False),
-            nn.BatchNorm2d(32),
-            nn.LeakyReLU(0.01),
-            nn.Conv2d(32, 32, 3,bias=False),
-            nn.BatchNorm2d(32),
-            nn.MaxPool2d(3),
-            nn.LeakyReLU(0.01),
-            nn.Conv2d(32, 32, 3, bias=False),
-            nn.BatchNorm2d(32),
-            nn.LeakyReLU(0.01),
-            nn.Conv2d(32, 32, 3, bias=False),
-            nn.BatchNorm2d(32),
-            nn.LeakyReLU(0.01),
-            nn.Conv2d(32, 64, (19, 3), bias=False),
-            nn.BatchNorm2d(64),
-            nn.MaxPool2d(3),
-            nn.LeakyReLU(0.01),
-            nn.Dropout(p=.5),
-            nn.Conv2d(64, 256, (1, 9), bias=False),
-            nn.BatchNorm2d(256),
-            nn.LeakyReLU(0.01),
-            nn.Dropout(p=.5),
-            nn.Conv2d(256, 64, 1, bias=False),
-            nn.BatchNorm2d(64),
-            nn.LeakyReLU(0.01),
-            nn.Dropout(p=.5),
-            nn.Conv2d(64, 1, 1, bias=False),
-            nn.MaxPool2d((6, 1))
+    },
+    'megaptera' : {
+        'weights': 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc',
+        'fs': 11025,
+        'archi': nn.Sequential(
+            nn.Sequential(
+                STFT(512, 64),
+                MelFilter(11025, 512, 64, 100, 3000),
+                PCENLayer(64)
+            ),
+            nn.Sequential(
+                nn.Conv2d(1, 32, 3, bias=False),
+                nn.BatchNorm2d(32),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 32, 3,bias=False),
+                nn.BatchNorm2d(32),
+                nn.MaxPool2d(3),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 32, 3, bias=False),
+                nn.BatchNorm2d(32),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 32, 3, bias=False),
+                nn.BatchNorm2d(32),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 64, (16, 3), bias=False),
+                nn.BatchNorm2d(64),
+                nn.MaxPool2d((1,3)),
+                nn.LeakyReLU(0.01),
+                nn.Dropout(p=.5),
+                nn.Conv2d(64, 256, (1, 9), bias=False),
+                nn.BatchNorm2d(256),
+                nn.LeakyReLU(0.01),
+                nn.Dropout(p=.5),
+                nn.Conv2d(256, 64, 1, bias=False),
+                nn.BatchNorm2d(64),
+                nn.LeakyReLU(0.01),
+                nn.Dropout(p=.5),
+                nn.Conv2d(64, 1, 1, bias=False)
+            )
+        )
+    },
+    'delphinid' : {
+        'weights': 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc',
+        'fs': 96000,
+        'archi': nn.Sequential(
+            nn.Sequential(
+                STFT(4096, 1024),
+                MelFilter(96000, 4096, 128, 3000, 30000),
+                PCENLayer(128)
+            ),
+            nn.Sequential(
+                nn.Conv2d(1, 32, 3, bias=False),
+                nn.BatchNorm2d(32),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 32, 3,bias=False),
+                nn.BatchNorm2d(32),
+                nn.MaxPool2d(3),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 32, 3, bias=False),
+                nn.BatchNorm2d(32),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 32, 3, bias=False),
+                nn.BatchNorm2d(32),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 64, (19, 3), bias=False),
+                nn.BatchNorm2d(64),
+                nn.MaxPool2d(3),
+                nn.LeakyReLU(0.01),
+                nn.Dropout(p=.5),
+                nn.Conv2d(64, 256, (1, 9), bias=False),
+                nn.BatchNorm2d(256),
+                nn.LeakyReLU(0.01),
+                nn.Dropout(p=.5),
+                nn.Conv2d(256, 64, 1, bias=False),
+                nn.BatchNorm2d(64),
+                nn.LeakyReLU(0.01),
+                nn.Dropout(p=.5),
+                nn.Conv2d(64, 1, 1, bias=False),
+                nn.MaxPool2d((6, 1))
+            )
+        )
+    },
+    'orcinus': {
+        'weights': 'train_fe76f_00085_85_0',
+        'fs': 22050,
+        'archi': nn.Sequential(
+            nn.Sequential(
+                STFT(1024, 128),
+                MelFilter(22050, 1024, 80, 300, 11025),
+                PCENLayer(80)
+            ),
+            nn.Sequential(
+                nn.Conv2d(1, 32, 3, bias=False),
+                nn.BatchNorm2d(32),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 32, 3,bias=False),
+                nn.BatchNorm2d(32),
+                nn.MaxPool2d(3),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 32, 3, bias=False),
+                nn.BatchNorm2d(32),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 32, 3, bias=False),
+                nn.BatchNorm2d(32),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(32, 64, (19, 3), bias=False),
+                nn.BatchNorm2d(64),
+                nn.MaxPool2d(3),
+                nn.LeakyReLU(0.01),
+                nn.Dropout2d(p=.5),
+                nn.Conv2d(64, 256, (1, 9), bias=False),  # for 80 bands
+                nn.BatchNorm2d(256),
+                nn.LeakyReLU(0.01),
+                nn.Dropout2d(p=.5),
+                nn.Conv2d(256, 64, 1, bias=False),
+                nn.BatchNorm2d(64),
+                nn.Dropout2d(p=.5),
+                nn.LeakyReLU(0.01),
+                nn.Conv2d(64, 1, 1, bias=False),
+                nn.AdaptiveMaxPool2d(output_size=(1, 1))
+            )
         )
-    )
+    }
 }
diff --git a/run_CNN.py b/run_CNN.py
index 52606ad333192aa2202e8a369edeaab3217dd7a5..60789bb1e567ebb47fb3644ed36644495ab4e933 100644
--- a/run_CNN.py
+++ b/run_CNN.py
@@ -1,9 +1,8 @@
 import os
 import torch
 import models
-from scipy import signal
+from scipy import signal, special
 import soundfile as sf
-from torch.utils import data
 import numpy as np
 import pandas as pd
 from tqdm import tqdm
@@ -12,7 +11,6 @@ import argparse
 parser = argparse.ArgumentParser(description="Run this script to use a CNN for inference on a folder of audio files.")
 parser.add_argument('audio_folder', type=str, help='Path of the folder with audio files to process')
 parser.add_argument('specie', type=str, help='Target specie to detect', choices=['megaptera', 'delphinid', 'orcinus', 'physeter', 'balaenoptera'])
-parser.add_argument('pred_fn', type=str, help='Filename for the output table containing model predictions')
 parser.add_argument('-lensample', type=float, help='Length of the signal excerpts to process (sec)', default=5),
 parser.add_argument('-batch_size', type=int, help='Amount of samples to process at a time', default=32),
 parser.add_argument('-maxPool', help='Wether to keep only the maximal prediction of a sample or the full sequence', action='store_true'),
@@ -20,55 +18,32 @@ parser.add_argument('-no-maxPool', dest='maxPool', action='store_false')
 parser.set_defaults(maxPool=True)
 args = parser.parse_args()
 
-meta_model = {
-    'delphinid': {
-        'stdc': 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc',
-        'fs': 96000
-    },
-    'megaptera': {
-        'stdc': 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc',
-        'fs': 11025
-    },
-    'orcinus': '',
-    'physeter': {
-        'stdc': 'stft_depthwise_ovs_128_k7_r1.stdc',
-        'fs': 50000
-    },
-    'balaenoptera': {
-        'stdc': 'dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc',
-        'fs': 200
-    }
-}[args.specie]
-
-
 def collate_fn(batch):
     batch = list(filter(lambda x: x is not None, batch))
-    return data.dataloader.default_collate(batch) if len(batch) > 0 else None
+    return torch.utils.data.dataloader.default_collate(batch) if len(batch) > 0 else None
 
 norm = lambda arr: (arr - np.mean(arr) ) / np.std(arr)
 
-class Dataset(data.Dataset):
+# Pytorch dataset class to load audio samples
+class Dataset(torch.utils.data.Dataset):
     def __init__(self, folder, fs, lensample):
         super(Dataset, self)
-        print('initializing dataset...')
+        self.fs, self.folder, self.lensample = fs, folder, lensample
         self.samples = []
-        for fn in os.listdir(folder):
+        for fn in tqdm(os.listdir(folder), desc='Dataset initialization', leave=False):
             try:
-                duration = sf.info(folder+fn).duration
+                info = sf.info(folder+fn)
+                duration, fs = info.duration, info.samplerate
+                self.samples.extend([{'fn':fn, 'offset':offset, 'fs':fs} for offset in np.arange(0, duration+.01-lensample, lensample)])
             except:
-                print(f'Skipping {fn} (unable to read as audio)')
                 continue
-            self.samples.extend([{'fn':fn, 'offset':offset} for offset in np.arange(0, duration+.01-lensample, lensample)])
-        self.fs, self.folder, self.lensample = fs, folder, lensample
-
     def __len__(self):
         return len(self.samples)
 
     def __getitem__(self, idx):
         sample = self.samples[idx]
-        fs = sf.info(self.folder+sample['fn']).samplerate
         try:
-            sig, fs = sf.read(self.folder+sample['fn'], start=int(sample['offset']*fs), stop=int((sample['offset']+self.lensample)*fs), always_2d=True)
+            sig, fs = sf.read(self.folder+sample['fn'], start=int(sample['offset']*sample['fs']), stop=int((sample['offset']+self.lensample)*sample['fs']), always_2d=True)
         except:
             print('Failed loading '+sample['fn'])
             return None
@@ -80,24 +55,27 @@ class Dataset(data.Dataset):
 
 
 # prepare model
-model = models.get[args.specie]
-model.load_state_dict(torch.load(f"weights/{meta_model['stdc']}"))
+model = models.get[args.specie]['archi']
+model.load_state_dict(torch.load(f"weights/{models.get[args.specie]['weights']}"))
 model.eval()
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model.to(device)
 
 # prepare data loader and output storage for predictions
-loader = data.DataLoader(Dataset(args.audio_folder, meta_model['fs'], args.lensample), batch_size=args.batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4)
-out = pd.DataFrame(columns=['filename', 'offset', 'prediction'])
-fns, offsets, preds = [], [], []
+loader = torch.utils.data.DataLoader(Dataset(args.audio_folder, models.get[args.specie]['fs'], args.lensample),
+                                     batch_size=args.batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4)
 if len(loader) == 0:
-    print('Unable to open any audio file in the given folder')
+    print(f'Unable to open any audio file in the given folder {args.audiofolder}')
     exit()
 
+out = pd.DataFrame(columns=['filename', 'offset', 'prediction'])
+fns, offsets, preds = [], [], []
+
+# forward the model on each batch
 with torch.no_grad():
-    for x, meta in tqdm(loader):
+    for x, meta in tqdm(loader, desc='Model inference'):
         x = x.to(device)
-        pred = model(x).cpu().detach().numpy()
+        pred = special.expit(model(x).cpu().detach().numpy())
         if args.maxPool:
             pred = pred.max(axis=-1).reshape(len(x))
         else:
@@ -107,4 +85,9 @@ with torch.no_grad():
         offsets.extend(meta['offset'].numpy())
 
 out.filename, out.offset, out.prediction = fns, offsets, preds
-out.to_pickle(args.pred_fn)
+pred_fn = list(filter(lambda e: e!='', args.audio_folder.split('/')))[-1] + ('.csv' if args.maxPool else '.pkl')
+print(f'Saving results into {pred_fn}')
+if args.maxPool:
+    out.to_csv(pred_fn, index=False)
+else:
+    out.to_pickle(pred_fn)
\ No newline at end of file
diff --git a/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc b/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc
index dd62d9ee17bfda5a56610aaa10a36473a59034cc..b1311077fc9b6da5c5c8960e745d19611f88a465 100644
Binary files a/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc and b/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc differ
diff --git a/weights/stft_depthwise_ovs_128_k7_r1.stdc b/weights/stft_depthwise_ovs_128_k7_r1.stdc
index e131f2b536060353a50b14963462863205981944..9113b5e36c801160eda581d0b6b72c95c96f6348 100644
Binary files a/weights/stft_depthwise_ovs_128_k7_r1.stdc and b/weights/stft_depthwise_ovs_128_k7_r1.stdc differ
diff --git a/weights/train_fe76f_00085_85_0 b/weights/train_fe76f_00085_85_0
new file mode 100644
index 0000000000000000000000000000000000000000..3aa1c971c3cf16436fb51bfb29564a2eb5a7039b
Binary files /dev/null and b/weights/train_fe76f_00085_85_0 differ