diff --git a/models.py b/models.py index 75660662c698d0c32e6f7f09ea74a560f8caafa6..730b943f1e1d2aa0b8f0f3a9ffd7a0c23039d28f 100644 --- a/models.py +++ b/models.py @@ -8,7 +8,7 @@ class depthwise_separable_conv1d(nn.Module): self.depthwise = nn.Conv1d(nin, nin, kernel_size=kernel, padding=padding, stride=stride, groups=nin) self.pointwise = nn.Conv1d(nin, nout, kernel_size=1) def forward(self, x): - out = self.depthwise(x) + out = self.depthwise(x.squeeze(1)) out = self.pointwise(out) return out @@ -27,105 +27,162 @@ BALAENOPTERA_NFEAT = 128 BALAENOPTERA_KERNEL = 5 get = { - 'physeter' : nn.Sequential( - STFT(512, 256), - MelFilter(50000, 512, 64, 2000, 25000), - Log1p(), - depthwise_separable_conv1d(64, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2), - nn.BatchNorm1d(PHYSETER_NFEAT), - nn.LeakyReLU(), - Dropout1d(), - depthwise_separable_conv1d(PHYSETER_NFEAT, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2), - nn.BatchNorm1d(PHYSETER_NFEAT), - nn.LeakyReLU(), - Dropout1d(), - depthwise_separable_conv1d(PHYSETER_NFEAT, 1, PHYSETER_KERNEL, stride=2) - ), - 'balaenoptera': nn.Sequential( - STFT(256, 32), - MelFilter(200, 256, 128, 0, 100), - Log1p(), - depthwise_separable_conv1d(128, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), - nn.BatchNorm1d(BALAENOPTERA_NFEAT), - nn.LeakyReLU(), - Dropout1d(), - depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), - nn.BatchNorm1d(BALAENOPTERA_NFEAT), - nn.LeakyReLU(), - Dropout1d(), - depthwise_separable_conv1d(BALAENOPTERA_NFEAT, 1, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2) - ), - 'megaptera' : nn.Sequential( - nn.Sequential( - STFT(512, 64), - MelFilter(11025, 512, 64, 100, 3000), - PCENLayer(64) + 'physeter': { + 'weights': 'stft_depthwise_ovs_128_k7_r1.stdc', + 'fs': 50000, + 'archi': nn.Sequential( + STFT(512, 256), + MelFilter(50000, 512, 64, 2000, 25000), + Log1p(trainable=True), + 
depthwise_separable_conv1d(64, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2), + nn.BatchNorm1d(PHYSETER_NFEAT), + nn.LeakyReLU(), + Dropout1d(), + depthwise_separable_conv1d(PHYSETER_NFEAT, PHYSETER_NFEAT, PHYSETER_KERNEL, stride=2), + nn.BatchNorm1d(PHYSETER_NFEAT), + nn.LeakyReLU(), + Dropout1d(), + depthwise_separable_conv1d(PHYSETER_NFEAT, 1, PHYSETER_KERNEL, stride=2) ), - nn.Sequential( - nn.Conv2d(1, 32, 3, bias=False), - nn.BatchNorm2d(32), - nn.LeakyReLU(0.01), - nn.Conv2d(32, 32, 3,bias=False), - nn.BatchNorm2d(32), - nn.MaxPool2d(3), - nn.LeakyReLU(0.01), - nn.Conv2d(32, 32, 3, bias=False), - nn.BatchNorm2d(32), - nn.LeakyReLU(0.01), - nn.Conv2d(32, 32, 3, bias=False), - nn.BatchNorm2d(32), - nn.LeakyReLU(0.01), - nn.Conv2d(32, 64, (16, 3), bias=False), - nn.BatchNorm2d(64), - nn.MaxPool2d((1,3)), - nn.LeakyReLU(0.01), - nn.Dropout(p=.5), - nn.Conv2d(64, 256, (1, 9), bias=False), - nn.BatchNorm2d(256), - nn.LeakyReLU(0.01), - nn.Dropout(p=.5), - nn.Conv2d(256, 64, 1, bias=False), - nn.BatchNorm2d(64), - nn.LeakyReLU(0.01), - nn.Dropout(p=.5), - nn.Conv2d(64, 1, 1, bias=False) + }, + 'balaenoptera': { + 'weights': 'dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc', + 'fs': 200, + 'archi': nn.Sequential( + STFT(256, 32), + MelFilter(200, 256, 128, 0, 100), + Log1p(trainable=True), + depthwise_separable_conv1d(128, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), + nn.BatchNorm1d(BALAENOPTERA_NFEAT), + nn.LeakyReLU(), + Dropout1d(), + depthwise_separable_conv1d(BALAENOPTERA_NFEAT, BALAENOPTERA_NFEAT, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2), + nn.BatchNorm1d(BALAENOPTERA_NFEAT), + nn.LeakyReLU(), + Dropout1d(), + depthwise_separable_conv1d(BALAENOPTERA_NFEAT, 1, kernel=BALAENOPTERA_KERNEL, padding=BALAENOPTERA_KERNEL//2) ) - ), - 'delphinid' : nn.Sequential( - nn.Sequential( - STFT(4096, 1024), - MelFilter(96000, 4096, 128, 3000, 30000), - PCENLayer(128) - ), - nn.Sequential( - nn.Conv2d(1, 32, 3, 
bias=False), - nn.BatchNorm2d(32), - nn.LeakyReLU(0.01), - nn.Conv2d(32, 32, 3,bias=False), - nn.BatchNorm2d(32), - nn.MaxPool2d(3), - nn.LeakyReLU(0.01), - nn.Conv2d(32, 32, 3, bias=False), - nn.BatchNorm2d(32), - nn.LeakyReLU(0.01), - nn.Conv2d(32, 32, 3, bias=False), - nn.BatchNorm2d(32), - nn.LeakyReLU(0.01), - nn.Conv2d(32, 64, (19, 3), bias=False), - nn.BatchNorm2d(64), - nn.MaxPool2d(3), - nn.LeakyReLU(0.01), - nn.Dropout(p=.5), - nn.Conv2d(64, 256, (1, 9), bias=False), - nn.BatchNorm2d(256), - nn.LeakyReLU(0.01), - nn.Dropout(p=.5), - nn.Conv2d(256, 64, 1, bias=False), - nn.BatchNorm2d(64), - nn.LeakyReLU(0.01), - nn.Dropout(p=.5), - nn.Conv2d(64, 1, 1, bias=False), - nn.MaxPool2d((6, 1)) + }, + 'megaptera' : { + 'weights': 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc', + 'fs': 11025, + 'archi': nn.Sequential( + nn.Sequential( + STFT(512, 64), + MelFilter(11025, 512, 64, 100, 3000), + PCENLayer(64) + ), + nn.Sequential( + nn.Conv2d(1, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3,bias=False), + nn.BatchNorm2d(32), + nn.MaxPool2d(3), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 64, (16, 3), bias=False), + nn.BatchNorm2d(64), + nn.MaxPool2d((1,3)), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(64, 256, (1, 9), bias=False), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(256, 64, 1, bias=False), + nn.BatchNorm2d(64), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(64, 1, 1, bias=False) + ) + ) + }, + 'delphinid' : { + 'weights': 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc', + 'fs': 96000, + 'archi': nn.Sequential( + nn.Sequential( + STFT(4096, 1024), + MelFilter(96000, 4096, 128, 3000, 30000), + PCENLayer(128) + ), + nn.Sequential( + nn.Conv2d(1, 32, 3, bias=False), + nn.BatchNorm2d(32), 
+ nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3,bias=False), + nn.BatchNorm2d(32), + nn.MaxPool2d(3), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 64, (19, 3), bias=False), + nn.BatchNorm2d(64), + nn.MaxPool2d(3), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(64, 256, (1, 9), bias=False), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(256, 64, 1, bias=False), + nn.BatchNorm2d(64), + nn.LeakyReLU(0.01), + nn.Dropout(p=.5), + nn.Conv2d(64, 1, 1, bias=False), + nn.MaxPool2d((6, 1)) + ) + ) + }, + 'orcinus': { + 'weights': 'train_fe76f_00085_85_0', + 'fs': 22050, + 'archi': nn.Sequential( + nn.Sequential( + STFT(1024, 128), + MelFilter(22050, 1024, 80, 300, 11025), + PCENLayer(80) + ), + nn.Sequential( + nn.Conv2d(1, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3,bias=False), + nn.BatchNorm2d(32), + nn.MaxPool2d(3), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 32, 3, bias=False), + nn.BatchNorm2d(32), + nn.LeakyReLU(0.01), + nn.Conv2d(32, 64, (19, 3), bias=False), + nn.BatchNorm2d(64), + nn.MaxPool2d(3), + nn.LeakyReLU(0.01), + nn.Dropout2d(p=.5), + nn.Conv2d(64, 256, (1, 9), bias=False), # for 80 bands + nn.BatchNorm2d(256), + nn.LeakyReLU(0.01), + nn.Dropout2d(p=.5), + nn.Conv2d(256, 64, 1, bias=False), + nn.BatchNorm2d(64), + nn.Dropout2d(p=.5), + nn.LeakyReLU(0.01), + nn.Conv2d(64, 1, 1, bias=False), + nn.AdaptiveMaxPool2d(output_size=(1, 1)) + ) ) - ) + } } diff --git a/run_CNN.py b/run_CNN.py index 52606ad333192aa2202e8a369edeaab3217dd7a5..60789bb1e567ebb47fb3644ed36644495ab4e933 100644 --- a/run_CNN.py +++ b/run_CNN.py @@ -1,9 +1,8 @@ import os import torch import models -from scipy import signal +from scipy import signal, special import soundfile as sf -from 
torch.utils import data import numpy as np import pandas as pd from tqdm import tqdm @@ -12,7 +11,6 @@ import argparse parser = argparse.ArgumentParser(description="Run this script to use a CNN for inference on a folder of audio files.") parser.add_argument('audio_folder', type=str, help='Path of the folder with audio files to process') parser.add_argument('specie', type=str, help='Target specie to detect', choices=['megaptera', 'delphinid', 'orcinus', 'physeter', 'balaenoptera']) -parser.add_argument('pred_fn', type=str, help='Filename for the output table containing model predictions') parser.add_argument('-lensample', type=float, help='Length of the signal excerpts to process (sec)', default=5), parser.add_argument('-batch_size', type=int, help='Amount of samples to process at a time', default=32), parser.add_argument('-maxPool', help='Wether to keep only the maximal prediction of a sample or the full sequence', action='store_true'), @@ -20,55 +18,32 @@ parser.add_argument('-no-maxPool', dest='maxPool', action='store_false') parser.set_defaults(maxPool=True) args = parser.parse_args() -meta_model = { - 'delphinid': { - 'stdc': 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc', - 'fs': 96000 - }, - 'megaptera': { - 'stdc': 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc', - 'fs': 11025 - }, - 'orcinus': '', - 'physeter': { - 'stdc': 'stft_depthwise_ovs_128_k7_r1.stdc', - 'fs': 50000 - }, - 'balaenoptera': { - 'stdc': 'dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc', - 'fs': 200 - } -}[args.specie] - - def collate_fn(batch): batch = list(filter(lambda x: x is not None, batch)) - return data.dataloader.default_collate(batch) if len(batch) > 0 else None + return torch.utils.data.dataloader.default_collate(batch) if len(batch) > 0 else None norm = lambda arr: (arr - np.mean(arr) ) / np.std(arr) -class Dataset(data.Dataset): +# Pytorch dataset class to load audio samples +class Dataset(torch.utils.data.Dataset): def 
__init__(self, folder, fs, lensample): super(Dataset, self) - print('initializing dataset...') + self.fs, self.folder, self.lensample = fs, folder, lensample self.samples = [] - for fn in os.listdir(folder): + for fn in tqdm(os.listdir(folder), desc='Dataset initialization', leave=False): try: - duration = sf.info(folder+fn).duration + info = sf.info(os.path.join(folder, fn)) + duration, fs = info.duration, info.samplerate + self.samples.extend([{'fn':fn, 'offset':offset, 'fs':fs} for offset in np.arange(0, duration+.01-lensample, lensample)]) except: - print(f'Skipping {fn} (unable to read as audio)') continue - self.samples.extend([{'fn':fn, 'offset':offset} for offset in np.arange(0, duration+.01-lensample, lensample)]) - self.fs, self.folder, self.lensample = fs, folder, lensample - def __len__(self): return len(self.samples) def __getitem__(self, idx): sample = self.samples[idx] - fs = sf.info(self.folder+sample['fn']).samplerate try: - sig, fs = sf.read(self.folder+sample['fn'], start=int(sample['offset']*fs), stop=int((sample['offset']+self.lensample)*fs), always_2d=True) + sig, fs = sf.read(os.path.join(self.folder, sample['fn']), start=int(sample['offset']*sample['fs']), stop=int((sample['offset']+self.lensample)*sample['fs']), always_2d=True) except: print('Failed loading '+sample['fn']) return None @@ -80,24 +55,27 @@ class Dataset(data.Dataset): # prepare model -model = models.get[args.specie] -model.load_state_dict(torch.load(f"weights/{meta_model['stdc']}")) +model = models.get[args.specie]['archi'] +model.load_state_dict(torch.load(f"weights/{models.get[args.specie]['weights']}", map_location='cpu')) model.eval() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.to(device) # prepare data loader and output storage for predictions -loader = data.DataLoader(Dataset(args.audio_folder, meta_model['fs'], args.lensample), batch_size=args.batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4) -out = pd.DataFrame(columns=['filename', 'offset', 'prediction']) -fns, 
offsets, preds = [], [] +loader = torch.utils.data.DataLoader(Dataset(args.audio_folder, models.get[args.specie]['fs'], args.lensample), + batch_size=args.batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4) if len(loader) == 0: - print('Unable to open any audio file in the given folder') + print(f'Unable to open any audio file in the given folder {args.audio_folder}') exit() +out = pd.DataFrame(columns=['filename', 'offset', 'prediction']) +fns, offsets, preds = [], [] + +# forward the model on each batch with torch.no_grad(): - for x, meta in tqdm(loader): + for x, meta in tqdm(loader, desc='Model inference'): x = x.to(device) - pred = model(x).cpu().detach().numpy() + pred = special.expit(model(x).cpu().detach().numpy()) if args.maxPool: pred = pred.max(axis=-1).reshape(len(x)) else: @@ -107,4 +85,9 @@ with torch.no_grad(): offsets.extend(meta['offset'].numpy()) out.filename, out.offset, out.prediction = fns, offsets, preds -out.to_pickle(args.pred_fn) +pred_fn = list(filter(lambda e: e!='', args.audio_folder.split('/')))[-1] + ('.csv' if args.maxPool else '.pkl') +print(f'Saving results into {pred_fn}') +if args.maxPool: + out.to_csv(pred_fn, index=False) +else: + out.to_pickle(pred_fn) \ No newline at end of file diff --git a/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc b/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc index dd62d9ee17bfda5a56610aaa10a36473a59034cc..b1311077fc9b6da5c5c8960e745d19611f88a465 100644 Binary files a/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc and b/weights/dw_m128_brown_200Hzhps32_prod_w4_128_k5_r_sch97.stdc differ diff --git a/weights/stft_depthwise_ovs_128_k7_r1.stdc b/weights/stft_depthwise_ovs_128_k7_r1.stdc index e131f2b536060353a50b14963462863205981944..9113b5e36c801160eda581d0b6b72c95c96f6348 100644 Binary files a/weights/stft_depthwise_ovs_128_k7_r1.stdc and b/weights/stft_depthwise_ovs_128_k7_r1.stdc differ diff --git a/weights/train_fe76f_00085_85_0 
b/weights/train_fe76f_00085_85_0 new file mode 100644 index 0000000000000000000000000000000000000000..3aa1c971c3cf16436fb51bfb29564a2eb5a7039b Binary files /dev/null and b/weights/train_fe76f_00085_85_0 differ