diff --git a/model.py b/frontend.py
similarity index 70%
rename from model.py
rename to frontend.py
index b3ba93dcf74976543e967fcff696abe5d7e0b097..aa50b4b4dc7681443cd231df6221787d05a3466b 100644
--- a/model.py
+++ b/frontend.py
@@ -1,10 +1,29 @@
-from torch import nn
+
 import torch
 import numpy as np
-from torch import tensor, nn, exp, log, ones, stack
 
 
-class PCENLayer(nn.Module):
+class Log1p(torch.nn.Module):
+    """
+    Applies log(1 + 10**a * x), with scale fixed or trainable.
+    """
+    def __init__(self, a=0, trainable=False):
+        super(Log1p, self).__init__()
+        if trainable:
+            a = torch.nn.Parameter(torch.tensor(a, dtype=torch.get_default_dtype()))
+        self.a = a
+        self.trainable = trainable
+
+    def forward(self, x):
+        if self.trainable or self.a != 0:
+            x = torch.log1p(10 ** self.a * x)
+        return x
+
+    def extra_repr(self):
+        return 'trainable={}'.format(repr(self.trainable))
+
+
+class PCENLayer(torch.nn.Module):
     def __init__(self, num_bands,
                  s=0.025,
                  alpha=.8,
@@ -13,11 +32,11 @@ class PCENLayer(nn.Module):
                  eps=1e-6,
                  init_smoother_from_data=True):
         super(PCENLayer, self).__init__()
-        self.log_s = nn.Parameter( log(ones((1,1,num_bands)) * s))
-        self.log_alpha = nn.Parameter( log(ones((1,1,num_bands,1)) * alpha))
-        self.log_delta = nn.Parameter( log(ones((1,1,num_bands,1)) * delta))
-        self.log_r = nn.Parameter( log(ones((1,1,num_bands,1)) * r))
-        self.eps = tensor(eps)
+        self.log_s = torch.nn.Parameter( torch.log(torch.ones((1,1,num_bands)) * s))
+        self.log_alpha = torch.nn.Parameter( torch.log(torch.ones((1,1,num_bands,1)) * alpha))
+        self.log_delta = torch.nn.Parameter( torch.log(torch.ones((1,1,num_bands,1)) * delta))
+        self.log_r = torch.nn.Parameter( torch.log(torch.ones((1,1,num_bands,1)) * r))
+        self.eps = torch.tensor(eps)
         self.init_smoother_from_data = init_smoother_from_data
 
     def forward(self, input): # expected input (batch, channel, freqs, time)
@@ -27,11 +46,11 @@ class PCENLayer(nn.Module):
 
         filtered = [init]
         for iframe in range(1, input.shape[-1]):
-            filtered.append( (1-exp(self.log_s)) * filtered[iframe-1] + exp(self.log_s) * input[:,:,:,iframe] )
-        filtered = stack(filtered).permute(1,2,3,0)
+            filtered.append( (1-torch.exp(self.log_s)) * filtered[iframe-1] + torch.exp(self.log_s) * input[:,:,:,iframe] )
+        filtered = torch.stack(filtered).permute(1,2,3,0)
 
         # stable reformulation due to Vincent Lostanlen; original formula was:
-        alpha, delta, r = exp(self.log_alpha), exp(self.log_delta), exp(self.log_r)
+        alpha, delta, r = torch.exp(self.log_alpha), torch.exp(self.log_delta), torch.exp(self.log_r)
         return (input / (self.eps + filtered)**alpha + delta)**r - delta**r
 #        filtered = exp(-alpha * (log(self.eps) + log(1 + filtered / self.eps)))
 #        return (input * filtered + delta)**r - delta**r
@@ -80,7 +99,7 @@ def create_mel_filterbank(sample_rate, frame_len, num_bands, min_freq, max_freq,
     return filterbank
 
 
-class MelFilter(nn.Module):
+class MelFilter(torch.nn.Module):
     def __init__(self, sample_rate, winsize, num_bands, min_freq, max_freq):
         super(MelFilter, self).__init__()
         melbank = create_mel_filterbank(sample_rate, winsize, num_bands,
@@ -112,7 +131,8 @@ class MelFilter(nn.Module):
         self._buffers = buffers
         return result
 
-class STFT(nn.Module):
+
+class STFT(torch.nn.Module):
     def __init__(self, winsize, hopsize, complex=False):
         super(STFT, self).__init__()
         self.winsize = winsize
@@ -154,77 +174,3 @@ class STFT(nn.Module):
         # restore original batchsize and channels in case we mashed them
         x = x.reshape((batchsize, channels, -1) + x.shape[2:]) #if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:])
         return x
-
-
-HB_model = nn.Sequential(nn.Sequential(
-    STFT(512, 64),
-    MelFilter(11025, 512, 64, 100, 3000),
-    PCENLayer(64),
-  ),
-  nn.Sequential(
-    nn.Conv2d(1, 32, 3, bias=False),
-    nn.BatchNorm2d(32),
-    nn.LeakyReLU(0.01),
-    nn.Conv2d(32, 32, 3,bias=False),
-    nn.BatchNorm2d(32),
-    nn.MaxPool2d(3),
-    nn.LeakyReLU(0.01),
-    nn.Conv2d(32, 32, 3, bias=False),
-    nn.BatchNorm2d(32),
-    nn.LeakyReLU(0.01),
-    nn.Conv2d(32, 32, 3, bias=False),
-    nn.BatchNorm2d(32),
-    nn.LeakyReLU(0.01),
-    nn.Conv2d(32, 64, (16, 3), bias=False),
-    nn.BatchNorm2d(64),
-    nn.MaxPool2d((1,3)),
-    nn.LeakyReLU(0.01),
-    nn.Dropout(p=.5),
-    nn.Conv2d(64, 256, (1, 9), bias=False),  # for 80 bands
-    nn.BatchNorm2d(256),
-    nn.LeakyReLU(0.01),
-    nn.Dropout(p=.5),
-    nn.Conv2d(256, 64, 1, bias=False),
-    nn.BatchNorm2d(64),
-    nn.LeakyReLU(0.01),
-    nn.Dropout(p=.5),
-    nn.Conv2d(64, 1, 1, bias=False)
-  )
- )
-
-delphi_model = nn.Sequential(nn.Sequential(
-    STFT(4096, 1024),
-    MelFilter(96000, 4096, 128, 3000, 30000),
-    PCENLayer(128),
- ),
- nn.Sequential(
-  nn.Conv2d(1, 32, 3, bias=False),
-  nn.BatchNorm2d(32),
-  nn.LeakyReLU(0.01),
-  nn.Conv2d(32, 32, 3,bias=False),
-  nn.BatchNorm2d(32),
-  nn.MaxPool2d(3),
-  nn.LeakyReLU(0.01),
-  nn.Conv2d(32, 32, 3, bias=False),
-  nn.BatchNorm2d(32),
-  nn.LeakyReLU(0.01),
-  nn.Conv2d(32, 32, 3, bias=False),
-  nn.BatchNorm2d(32),
-  nn.LeakyReLU(0.01),
-  nn.Conv2d(32, 64, (19, 3), bias=False),
-  nn.BatchNorm2d(64),
-  nn.MaxPool2d(3),
-  nn.LeakyReLU(0.01),
-  nn.Dropout(p=.5),
-  nn.Conv2d(64, 256, (1, 9), bias=False),  # for 80 bands
-  nn.BatchNorm2d(256),
-  nn.LeakyReLU(0.01),
-  nn.Dropout(p=.5),
-  nn.Conv2d(256, 64, 1, bias=False),
-  nn.BatchNorm2d(64),
-  nn.LeakyReLU(0.01),
-  nn.Dropout(p=.5),
-  nn.Conv2d(64, 1, 1, bias=False),
-  nn.MaxPool2d((6, 1))
-  )
-)
diff --git a/models.py b/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfba07ad96e133530efd1a0794432e630f074962
--- /dev/null
+++ b/models.py
@@ -0,0 +1,81 @@
+from torch import nn
+from frontend import STFT, MelFilter, PCENLayer, Log1p
+
+
+
+get = {
+    'megaptera' : nn.Sequential(
+        nn.Sequential(
+            STFT(512, 64),
+            MelFilter(11025, 512, 64, 100, 3000),
+            PCENLayer(64)
+        ),
+        nn.Sequential(
+            nn.Conv2d(1, 32, 3, bias=False),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(0.01),
+            nn.Conv2d(32, 32, 3,bias=False),
+            nn.BatchNorm2d(32),
+            nn.MaxPool2d(3),
+            nn.LeakyReLU(0.01),
+            nn.Conv2d(32, 32, 3, bias=False),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(0.01),
+            nn.Conv2d(32, 32, 3, bias=False),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(0.01),
+            nn.Conv2d(32, 64, (16, 3), bias=False),
+            nn.BatchNorm2d(64),
+            nn.MaxPool2d((1,3)),
+            nn.LeakyReLU(0.01),
+            nn.Dropout(p=.5),
+            nn.Conv2d(64, 256, (1, 9), bias=False),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.01),
+            nn.Dropout(p=.5),
+            nn.Conv2d(256, 64, 1, bias=False),
+            nn.BatchNorm2d(64),
+            nn.LeakyReLU(0.01),
+            nn.Dropout(p=.5),
+            nn.Conv2d(64, 1, 1, bias=False),
+            nn.MaxPool2d((6, 1))
+        )
+    ),
+    'delphinid' : nn.Sequential(
+        nn.Sequential(
+            STFT(4096, 1024),
+            MelFilter(96000, 4096, 128, 3000, 30000),
+            PCENLayer(128)
+        ),
+        nn.Sequential(
+            nn.Conv2d(1, 32, 3, bias=False),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(0.01),
+            nn.Conv2d(32, 32, 3,bias=False),
+            nn.BatchNorm2d(32),
+            nn.MaxPool2d(3),
+            nn.LeakyReLU(0.01),
+            nn.Conv2d(32, 32, 3, bias=False),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(0.01),
+            nn.Conv2d(32, 32, 3, bias=False),
+            nn.BatchNorm2d(32),
+            nn.LeakyReLU(0.01),
+            nn.Conv2d(32, 64, (19, 3), bias=False),
+            nn.BatchNorm2d(64),
+            nn.MaxPool2d(3),
+            nn.LeakyReLU(0.01),
+            nn.Dropout(p=.5),
+            nn.Conv2d(64, 256, (1, 9), bias=False),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.01),
+            nn.Dropout(p=.5),
+            nn.Conv2d(256, 64, 1, bias=False),
+            nn.BatchNorm2d(64),
+            nn.LeakyReLU(0.01),
+            nn.Dropout(p=.5),
+            nn.Conv2d(64, 1, 1, bias=False),
+            nn.MaxPool2d((6, 1))
+        )
+    )
+}
diff --git a/run_CNN.py b/run_CNN.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c09ad4a8b88bd713323362b6fb7e61a8443527e
--- /dev/null
+++ b/run_CNN.py
@@ -0,0 +1,110 @@
+import os
+import torch
+import models
+from scipy import signal
+import soundfile as sf
+from torch.utils import data
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import argparse
+
+parser = argparse.ArgumentParser(description="Run this script to use a CNN for inference on a folder of audio files.")
+parser.add_argument('audio_folder', type=str, help='Path of the folder with audio files to process')
+parser.add_argument('specie', type=str, help='Target specie to detect', choices=['megaptera', 'delphinid', 'orcinus', 'physeter', 'balaenoptera'])
+parser.add_argument('pred_fn', type=str, help='Filename for the output table containing model predictions')
+parser.add_argument('-lensample', type=float, help='Length of the signal excerpts to process (sec)', default=5),
+parser.add_argument('-batchsize', type=int, help='Amount of samples to process at a time', default=32),
+parser.add_argument('-maxPool', type=bool, help='Wether to keep only the maximal prediction of a sample or the full sequence', default=True),
+
+args = parser.parse_args()
+
+meta_model = {
+    'delphinid': {
+        'stdc':'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc',
+        'fs': 96000
+    },
+    'megaptera': {
+        'stdc':'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc',
+        'fs':11025
+    },
+    'orcinus': '',
+    'physeter': '',
+    'balaenoptera': ''
+}[args.specie]
+
+
+def collate_fn(batch):
+    batch = list(filter(lambda x: x is not None, batch))
+    return data.dataloader.default_collate(batch) if len(batch) > 0 else None
+
+norm = lambda arr: (arr - np.mean(arr) ) / np.std(arr)
+
+
+def run(folder, stdcfile, model, fs, lensample, batch_size, maxPool):
+    model.load_state_dict(torch.load(stdcfile))
+    model.eval()
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model.to(device)
+
+    out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
+    fns, offsets, preds = [], [], []
+    loader = data.DataLoader(Dataset(folder, fs, lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4)
+    with torch.no_grad():
+        for x, meta in tqdm(loader):
+            x = x.to(device)
+            pred = model(x).cpu().detach().numpy()
+            if maxPool:
+                pred = np.maximum(pred)
+            else:
+                pred.reshape(len(x), -1)
+            fns.extend(meta['fn'])
+            offsets.extend(meta['offset'].numpy())
+            preds.extend(pred)
+    out.fn, out.offset, out.pred = fns, offsets, preds
+    return out
+
+
+class Dataset(data.Dataset):
+    def __init__(self, folder, fs, lensample):
+        super(Dataset, self)
+        print('initializing dataset...')
+        self.samples = []
+        for fn in os.listdir(folder):
+            try:
+                duration = sf.info(folder.fn).duration
+            except:
+                print(f'Skipping {fn} (unable to read)')
+                continue
+            for offset in np.arange(0, duration+.01-lensample, lensample):
+                self.samples.append({'fn':fn, 'offset':offset})
+        self.fs, self.folder, self.lensample = fs, folder, lensample
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        sample = self.samples[idx]
+        fs = sf.info(self.folder+sample['fn']).samplerate
+        try:
+            sig, fs = sf.read(self.folder+sample['fn'], start=int(sample['offset']*fs), stop=int((sample['offset']+self.lensample)*fs), always_2d=True)
+        except:
+            print('Failed loading '+sample['fn'])
+            return None
+        sig = sig[:,0]
+        if fs != self.fs:
+            sig = signal.resample(sig, self.lensample*self.fs)
+        sig = norm(sig)
+        return torch.tensor(sig).float(), sample
+
+
+preds = run(args.audio_folder,
+            meta_model['stdc'],
+            models.get[args.specie],
+            meta_model['fs'],
+            batch_size=args.batch_size,
+            lensample=args.lensample,
+            maxPool=args.maxPool
+        )
+
+preds.to_pickle(args.pred_fn)
diff --git a/run_CNN_HB.py b/run_CNN_HB.py
deleted file mode 100644
index 3bb7fc235791639e989d407c7984d1934acc90af..0000000000000000000000000000000000000000
--- a/run_CNN_HB.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from model import HB_model
-from scipy import signal
-import soundfile as sf
-from torch import load, no_grad, tensor, device, cuda
-from torch.utils import data
-import numpy as np
-import pandas as pd
-from tqdm import tqdm
-import argparse
-
-parser = argparse.ArgumentParser()
-parser.add_argument('files', type=str, nargs='+')
-parser.add_argument('-outfn', type=str, default='HB_preds.pkl')
-args = parser.parse_args()
-
-stdc = 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc'
-
-def collate_fn(batch):
-    batch = list(filter(lambda x: x is not None, batch))
-    return data.dataloader.default_collate(batch) if len(batch) > 0 else None
-
-def run(files, stdcfile, model, folder, pool=False, lensample=5, batch_size=32):
-    model.load_state_dict(load(stdcfile))
-    model.eval()
-    cuda0 = device('cuda' if cuda.is_available() else 'cpu')
-    model.to(cuda0)
-
-    out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
-    fns, offsets, preds = [], [], []
-    with no_grad():
-        for x, meta in tqdm(data.DataLoader(Dataset(files, folder, lensample=lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8,prefetch_factor=4)):
-            x = x.to(cuda0, non_blocking=True)
-            pred = model(x)
-            temp = pd.DataFrame().from_dict(meta)
-            fns.extend(meta['fn'])
-            offsets.extend(meta['offset'].numpy())
-            preds.extend(pred.reshape(len(x), -1).cpu().detach().numpy())
-#            print(meta, temp, pred.reshape(len(x), -1).shape)
-#            temp['pred'] = pred.reshape(len(x), -1).cpu().detach()
-#            preds = preds.append(temp, ignore_index=True)
-    out.fn, out.offset, out.pred = fns, offsets, preds
-    #preds.pred = preds.pred.apply(np.array)
-    return out
-
-
-
-class Dataset(data.Dataset):
-    def __init__(self, fns, folder, fe=11025, lenfile=120, lensample=50): # lenfile and lensample in seconds
-        super(Dataset, self)
-        print('init dataset')
-        self.samples = np.concatenate([[{'fn':fn, 'offset':offset} for offset in np.arange(0, sf.info(folder+fn).duration-lensample+1, lensample)] for fn in fns if sf.info(folder+fn).duration>10])
-        self.lensample = lensample
-        self.fe, self.folder = fe, folder
-
-    def __len__(self):
-        return len(self.samples)
-
-    def __getitem__(self, idx):
-        sample = self.samples[idx]
-        fs = sf.info(self.folder+sample['fn']).samplerate
-        try:
-            sig, fs = sf.read(self.folder+sample['fn'], start=max(0,int(sample['offset']*fs)), stop=int((sample['offset']+self.lensample)*fs))
-        except:
-            print('failed loading '+sample['fn'])
-            return None
-        if sig.ndim > 1:
-            sig = sig[:,0]
-        if len(sig) != fs*self.lensample:
-            print('to short file '+sample['fn']+' \n'+str(sig.shape))
-            return None
-        if fs != self.fe:
-            sig = signal.resample(sig, self.lensample*self.fe)
-
-        sig = norm(sig)
-        return tensor(sig).float(), sample
-
-def norm(arr):
-    return (arr - np.mean(arr) ) / np.std(arr)
-
-preds = run(args.files, stdc, HBmodel, './', batch_size=3, lensample=50)
-preds.to_pickle(args.outfn)
diff --git a/run_CNN_delphi.py b/run_CNN_delphi.py
deleted file mode 100644
index d40c04769018523c75dda9d5fde10c28ddcca56e..0000000000000000000000000000000000000000
--- a/run_CNN_delphi.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from model import delphi_model
-from scipy import signal
-import soundfile as sf
-from torch import load, no_grad, tensor, device, cuda
-from torch.utils import data
-import numpy as np
-import pandas as pd
-from tqdm import tqdm
-import argparse
-
-parser = argparse.ArgumentParser()
-parser.add_argument('files', type=str, nargs='+')
-parser.add_argument('-outfn', type=str, default='delphi_preds.pkl')
-args = parser.parse_args()
-
-stdc = 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc'
-
-def collate_fn(batch):
-    batch = list(filter(lambda x: x is not None, batch))
-    return data.dataloader.default_collate(batch) if len(batch) > 0 else None
-
-def run(files, stdcfile, model, folder, fe=96000, lensample=5, batch_size=32):
-    model.load_state_dict(load(stdcfile))
-    model.eval()
-    cuda0 = device('cuda' if cuda.is_available() else 'cpu')
-    model.to(cuda0)
-
-    out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
-    fns, offsets, preds = [], [], []
-    with no_grad():
-        for x, meta in tqdm(data.DataLoader(Dataset(files, folder, fe=fe, lensample=lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8,prefetch_factor=4)):
-            x = x.to(cuda0, non_blocking=True)
-            pred = model(x)
-            temp = pd.DataFrame().from_dict(meta)
-            fns.extend(meta['fn'])
-            offsets.extend(meta['offset'].numpy())
-            preds.extend(pred.reshape(len(x), -1).cpu().detach().numpy())
-#            print(meta, temp, pred.reshape(len(x), -1).shape)
-#            temp['pred'] = pred.reshape(len(x), -1).cpu().detach()
-#            preds = preds.append(temp, ignore_index=True)
-    out.fn, out.offset, out.pred = fns, offsets, preds
-    #preds.pred = preds.pred.apply(np.array)
-    return out
-
-
-
-class Dataset(data.Dataset):
-    def __init__(self, fns, folder, fe=96000, lenfile=120, lensample=50): # lenfile and lensample in seconds
-        super(Dataset, self)
-        print('init dataset')
-        self.samples = np.concatenate([[{'fn':fn, 'offset':offset} for offset in np.arange(0, sf.info(folder+fn).duration-lensample+1, lensample)] for fn in fns if sf.info(folder+fn).duration>10])
-        self.lensample = lensample
-        self.fe, self.folder = fe, folder
-
-    def __len__(self):
-        return len(self.samples)
-
-    def __getitem__(self, idx):
-        sample = self.samples[idx]
-        fs = sf.info(self.folder+sample['fn']).samplerate
-        try:
-            sig, fs = sf.read(self.folder+sample['fn'], start=max(0,int(sample['offset']*fs)), stop=int((sample['offset']+self.lensample)*fs))
-        except:
-            print('failed loading '+sample['fn'])
-            return None
-        if sig.ndim > 1:
-            sig = sig[:,0]
-        if len(sig) != fs*self.lensample:
-            print('to short file '+sample['fn']+' \n'+str(sig.shape))
-            return None
-        if fs != self.fe:
-            sig = signal.resample(sig, self.lensample*self.fe)
-
-        sig = norm(sig)
-        return tensor(sig).float(), sample
-
-def norm(arr):
-    return (arr - np.mean(arr) ) / np.std(arr)
-
-preds = run(args.files, stdc, delphi_model, './', batch_size=3, lensample=50)
-preds.to_pickle(args.outfn)