diff --git a/README.md b/README.md
index 306b0c595bf9ff25ea7b343a492018bbb186272c..efaddd27f8793f7b216c371e06f499963acafc88 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@ For example :
 `python run_CNN_HB.py file1.wav file2.wav -outfn predictions.pkl`
 
 This script relies on torch, pandas, numpy, scipy, and tqdm to run. Install dependencies with pip or conda.
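+For example, with pip: `pip install torch pandas numpy scipy tqdm`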
-If a GPU and cuda are available on the current machine, process will run on GPU for faster computation.
+If a GPU and CUDA are available on the current machine, processing will run on the GPU for faster computation.
 
-paul.best@univ-tln.fr for more information
+Contact paul.best@univ-tln.fr for more information.
diff --git a/models.py b/models.py
index abf66e23d20ead2b787fd2b7e7531ac4d89e1a84..75660662c698d0c32e6f7f09ea74a560f8caafa6 100644
--- a/models.py
+++ b/models.py
@@ -2,8 +2,64 @@ from torch import nn
 from frontend import STFT, MelFilter, PCENLayer, Log1p
 
 
+class depthwise_separable_conv1d(nn.Module):
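+    """Depthwise separable 1d convolution: a per-channel (depthwise) conv followed by a 1x1 (pointwise) conv, giving a cheaper approximation of a full Conv1d."""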
+    def __init__(self, nin, nout, kernel, padding=0, stride=1):
+        super(depthwise_separable_conv1d, self).__init__()
+        self.depthwise = nn.Conv1d(nin, nin, kernel_size=kernel, padding=padding, stride=stride, groups=nin)
+        self.pointwise = nn.Conv1d(nin, nout, kernel_size=1)
+    def forward(self, x):
+        out = self.depthwise(x)
+        out = self.pointwise(out)
+        return out
+
+class Dropout1d(nn.Module):
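+    """Dropout of entire channels for (batch, channel, time) tensors: the input is unsqueezed to 4d so that nn.Dropout2d zeroes whole channels rather than single elements."""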
+    def __init__(self, pdropout=.25):
+        super(Dropout1d, self).__init__()
+        self.dropout = nn.Dropout2d(pdropout)
+    def forward(self, x):
+        x = x.unsqueeze(-1)
+        x = self.dropout(x)
+        return x.squeeze(-1)
+
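+# number of feature maps and kernel sizes of the 'physeter' and 'balaenoptera' detectors below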
+PHYSETER_NFEAT = 128
+PHYSETER_KERNEL = 7
+BALAENOPTERA_NFEAT = 128
+BALAENOPTERA_KERNEL = 5
 
 get = {
+    'physeter' : nn.Sequential(
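+        # sperm whale (Physeter macrocephalus) detector, designed for 50 kHz recordings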
+        STFT(512, 256),
+        MelFilter(50000, 512, 64, 2000, 25000),
+        Log1p(),
+        depthwise_separable_conv1d(64, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2),
+        nn.BatchNorm1d(PHYSETER_NFEAT),
+        nn.LeakyReLU(),
+        Dropout1d(),
+        depthwise_separable_conv1d(PHYSETER_NFEAT, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2),
+        nn.BatchNorm1d(PHYSETER_NFEAT),
+        nn.LeakyReLU(),
+        Dropout1d(),
+        depthwise_separable_conv1d(PHYSETER_NFEAT, 1, PHYSETER_KERNEL, stride=2)
+    ),
+    'balaenoptera': nn.Sequential(
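+        # rorqual (Balaenoptera) detector, designed for 200 Hz recordings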
+        STFT(256, 32),
+        MelFilter(200, 256, 128, 0, 100),
+        Log1p(),
+        depthwise_separable_conv1d(128, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2),
+        nn.BatchNorm1d(BALAENOPTERA_NFEAT),
+        nn.LeakyReLU(),
+        Dropout1d(),
+        depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2),
+        nn.BatchNorm1d(BALAENOPTERA_NFEAT),
+        nn.LeakyReLU(),
+        Dropout1d(),
+        depthwise_separable_conv1d(BALAENOPTERA_NFEAT, 1, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2)
+    ),
     'megaptera' : nn.Sequential(
         nn.Sequential(
             STFT(512, 64),
diff --git a/run_CNN.py b/run_CNN.py
index dcd92d060c63da859ca2b35161acb4194b960395..52606ad333192aa2202e8a369edeaab3217dd7a5 100644
--- a/run_CNN.py
+++ b/run_CNN.py
@@ -20,19 +20,24 @@ parser.add_argument('-no-maxPool', dest='maxPool', action='store_false')
 parser.set_defaults(maxPool=True)
 args = parser.parse_args()
 
-
 meta_model = {
     'delphinid': {
-        'stdc':'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc',
+        'stdc': 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc',
         'fs': 96000
     },
     'megaptera': {
-        'stdc':'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc',
-        'fs':11025
+        'stdc': 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc',
+        'fs': 11025
     },
-    'orcinus': '',
+    'orcinus': '',  # no weights released for this species yet
-    'physeter': '',
-    'balaenoptera': ''
+    'physeter': {
+        'stdc': 'stft_depthwise_ovs_128_k7_r1.stdc',
+        'fs': 50000
+    },
+    'balaenoptera': {
+        'stdc': 'dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc',
+        'fs': 200
+    }
 }[args.specie]
 
 
@@ -42,33 +47,6 @@ def collate_fn(batch):
 
 norm = lambda arr: (arr - np.mean(arr) ) / np.std(arr)
 
-
-def run(folder, stdcfile, model, fs, lensample, batch_size, maxPool):
-    model.load_state_dict(torch.load(stdcfile))
-    model.eval()
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model.to(device)
-
-    out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
-    fns, offsets, preds = [], [], []
-    loader = data.DataLoader(Dataset(folder, fs, lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4)
-    if len(loader) == 0:
-        print('Unable to open any audio file in the given folder')
-    with torch.no_grad():
-        for x, meta in tqdm(loader):
-            x = x.to(device)
-            pred = model(x).cpu().detach().numpy()
-            if maxPool:
-                pred = pred.max(axis=-1).reshape(len(x))
-            else:
-                pred = pred.reshape(len(x), -1)
-            fns.extend(meta['fn'])
-            offsets.extend(meta['offset'].numpy())
-            preds.extend(pred)
-    out.fn, out.offset, out.pred = fns, offsets, preds
-    return out
-
-
 class Dataset(data.Dataset):
     def __init__(self, folder, fs, lensample):
-        super(Dataset, self)
+        super(Dataset, self).__init__()
@@ -78,10 +56,9 @@ class Dataset(data.Dataset):
             try:
                 duration = sf.info(folder+fn).duration
             except:
-                print(f'Skipping {fn} (unable to read)')
+                print(f'Skipping {fn} (unable to read as audio)')
                 continue
-            for offset in np.arange(0, duration+.01-lensample, lensample):
-                self.samples.append({'fn':fn, 'offset':offset})
+            self.samples.extend([{'fn':fn, 'offset':offset} for offset in np.arange(0, duration+.01-lensample, lensample)])
         self.fs, self.folder, self.lensample = fs, folder, lensample
 
     def __len__(self):
@@ -101,13 +78,34 @@
         sig = norm(sig)
         return torch.tensor(sig).float(), sample
 
-preds = run(args.audio_folder,
-            meta_model['stdc'],
-            models.get[args.specie],
-            meta_model['fs'],
-            args.lensample,
-            args.batch_size,
-            args.maxPool
-        )
 
-preds.to_pickle(args.pred_fn)
+# prepare model
+model = models.get[args.specie]
+model.load_state_dict(torch.load(f"weights/{meta_model['stdc']}"))
+model.eval()
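+# run on the GPU if one is available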
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+model.to(device)
+
+# prepare the data loader
+loader = data.DataLoader(Dataset(args.audio_folder, meta_model['fs'], args.lensample), batch_size=args.batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4)
+fns, offsets, preds = [], [], []
+if len(loader) == 0:
+    print('Unable to open any audio file in the given folder')
+    exit()
+
+with torch.no_grad():
+    for x, meta in tqdm(loader):
+        x = x.to(device)
+        pred = model(x).cpu().detach().numpy()
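+        # with max pooling, keep a single score per sample; otherwise keep the full score time series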
+        if args.maxPool:
+            pred = pred.max(axis=-1).reshape(len(x))
+        else:
+            pred = pred.reshape(len(x), -1)
+        preds.extend(pred)
+        fns.extend(meta['fn'])
+        offsets.extend(meta['offset'].numpy())
+
+out = pd.DataFrame({'filename': fns, 'offset': offsets, 'prediction': preds})
+out.to_pickle(args.pred_fn)
diff --git a/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc b/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc
new file mode 100644
index 0000000000000000000000000000000000000000..dd62d9ee17bfda5a56610aaa10a36473a59034cc
Binary files /dev/null and b/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc differ
diff --git a/sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc b/weights/sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc
similarity index 100%
rename from sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc
rename to weights/sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc
diff --git a/sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc b/weights/sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc
similarity index 100%
rename from sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc
rename to weights/sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc
diff --git a/weights/stft_depthwise_ovs_128_k7_r1.stdc b/weights/stft_depthwise_ovs_128_k7_r1.stdc
new file mode 100644
index 0000000000000000000000000000000000000000..e131f2b536060353a50b14963462863205981944
Binary files /dev/null and b/weights/stft_depthwise_ovs_128_k7_r1.stdc differ