diff --git a/models.py b/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..0213ef63e238188d3872c489b0dadd995f1338b8
--- /dev/null
+++ b/models.py
@@ -0,0 +1,89 @@
+from torch import nn
+import torch
+import torchvision
+
+
+resnet = torchvision.models.resnet50()
+# the network is fed single-channel spectrograms instead of RGB images
+resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
+# global max pooling replaces torchvision's default average pooling
+resnet.avgpool = nn.AdaptiveMaxPool2d(output_size=(1, 1))
+# 7 outputs: one logit per label listed in run_model.py
+resnet.fc = nn.Linear(2048, 7)
+
+"""
+Audio feature extraction routines.
+Author: Jan Schluter
+"""
+class STFT(nn.Module):
+    """
+    Computes magnitude spectrograms of raw signals using a Hann window.
+
+    Expects input of shape (batch, samples); returns (batch, 1, bins, frames).
+    """
+    def __init__(self, winsize, hopsize):
+        super(STFT, self).__init__()
+        self.winsize, self.hopsize = winsize, hopsize
+        self.register_buffer('window', torch.hann_window(winsize, periodic=False))
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        # keyword arguments: recent PyTorch deprecates passing these positionally
+        result = super(STFT, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+        # remove all buffers; we use them as cached constants
+        for k in self._buffers:
+            del result[prefix + k]
+        return result
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        # ignore stored buffers for backwards compatibility
+        for k in self._buffers:
+            state_dict.pop(prefix + k, None)
+        # temporarily hide the buffers; we do not want to restore them
+        buffers = self._buffers
+        self._buffers = {}
+        try:
+            result = super(STFT, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+        finally:
+            # always put the buffers back, even if loading raised
+            self._buffers = buffers
+        return result
+
+    def forward(self, x):
+        x = x.unsqueeze(1)
+        # we want each channel to be treated separately, so we mash
+        # up the channels and batch size and split them up afterwards
+        batchsize, channels = x.shape[:2]
+        x = x.reshape((-1,) + x.shape[2:])
+        # we apply the STFT; return_complex=False is deprecated in recent PyTorch
+        x = torch.stft(x, self.winsize, self.hopsize, window=self.window, center=False, return_complex=True)
+        # we compute magnitudes
+        x = x.abs()
+        # restore original batchsize and channels
+        x = x.reshape((batchsize, channels, -1) + x.shape[2:])
+        return x
+
+
+class Log1p(nn.Module):
+    """
+    Applies log(1 + 10**a * x), with scale fixed or trainable.
+    """
+    def __init__(self, a=0, trainable=False):
+        super(Log1p, self).__init__()
+        if trainable:
+            a = nn.Parameter(torch.tensor(a, dtype=torch.get_default_dtype()))
+        self.a = a
+        self.trainable = trainable
+
+    def forward(self, x):
+        if self.trainable or self.a != 0:
+            x = torch.log1p(10 ** self.a * x)
+        return x
+
+    def extra_repr(self):
+        return 'trainable={}'.format(repr(self.trainable))
+
+
+frontend_logstft = nn.Sequential(
+    STFT(256, 32),
+    Log1p(a=7, trainable=True)
+)
diff --git a/resnet_weights.stdc b/resnet_weights.stdc
new file mode 100644
index 0000000000000000000000000000000000000000..c41bc4dc1843f6f8e93c662b45a910cc1be73f7a
Binary files /dev/null and b/resnet_weights.stdc differ
diff --git a/run_model.py b/run_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..d2d66f6909b204037866701a30aa71715d81ed58
--- /dev/null
+++ b/run_model.py
@@ -0,0 +1,138 @@
+from models import frontend_logstft, resnet
+from scipy import signal, special
+import soundfile as sf
+import torch
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import argparse
+import os
+
+labels = ['Bp_20Plus', 'Bm_D', 'Bp-Downsweep', 'Bp_20Hz', 'Bm_Ant-A', 'Bm_Ant-Z', 'Bm_Ant-B']
+
+# the weights are shipped next to this script; do not depend on the working directory
+WEIGHTS_FN = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'resnet_weights.stdc')
+
+
+def run(files, folder, chunklen, batch_size):
+    """ Runs the CNN model over a list of sound files, and returns a table with predictions
+
+    Parameters
+    ----------
+    files : list
+        list of string with names of the sound files to process
+    folder : str
+        path to load the files from
+    chunklen : float
+        size in second used for chunking the sound files. Ending incomplete chunks won't be processed
+    batch_size : int
+        number of samples to process simultaneously (tune according to available memory space)
+
+    Returns
+    -------
+    out
+        Pandas DataFrame containing the predictions
+    """
+    # prepare the model (map_location lets GPU-saved weights load on CPU-only machines)
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    model = torch.nn.Sequential(frontend_logstft, resnet)
+    model.load_state_dict(torch.load(WEIGHTS_FN, map_location=device))
+    model.eval()
+    model.to(device)
+
+    # run the model over each batch and store predictions
+    loader = torch.utils.data.DataLoader(Dataset(files, folder, chunklen=chunklen), batch_size=batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4)
+    fns, offsets, preds = [], [], []
+    with torch.no_grad():
+        for batch in tqdm(loader):
+            # collate_fn returns None when every chunk of the batch failed to load
+            if batch is None:
+                continue
+            x, meta = batch
+            pred = model(x.to(device))
+            fns.extend(meta['fn'])
+            offsets.extend(meta['offset'].numpy())
+            preds.extend(pred.view(len(x), len(labels)).cpu().numpy())
+
+    # prepare the output table, and eventually return it
+    logits = np.array(preds).reshape(-1, len(labels))  # stays 2-D even without any prediction
+    out = pd.DataFrame({'fn': fns, 'offset': offsets})
+    for i, l in enumerate(labels):
+        out['pred_'+l] = special.expit(logits[:, i])  # expit == sigmoid (normalise predictions between 0 and 1)
+    return out
+
+
+class Dataset(torch.utils.data.Dataset):
+    """
+    Class used to load each chunk of signal into batches
+
+    Attributes
+    ----------
+    fns : list
+        filenames to process
+    folder : str
+        path to load the files from
+    chunklen : float
+        size in second used for chunking the sound files. Ending incomplete chunks won't be processed
+    """
+    def __init__(self, fns, folder, chunklen):
+        super(Dataset, self).__init__()
+        print('dataset initialisation ...', end='')
+        self.fs, self.folder, self.chunklen = 250, folder, chunklen
+        # one sample per complete chunk; the +.01 keeps the last chunk when the
+        # duration is an exact multiple of chunklen (np.arange excludes its stop)
+        self.samples = [
+            {'fn': fn, 'offset': float(offset)}
+            for fn in fns
+            for offset in np.arange(0, sf.info(os.path.join(folder, fn)).duration-chunklen+.01, chunklen)
+        ]
+        print('...done!')
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        sample = self.samples[idx]
+        path = os.path.join(self.folder, sample['fn'])
+        try:
+            fs = sf.info(path).samplerate
+            sig, fs = sf.read(path, start=int(sample['offset']*fs), stop=int((sample['offset']+self.chunklen)*fs))
+        except Exception as e:
+            # best effort: report the chunk and let collate_fn drop it
+            print('failed loading '+sample['fn']+' ('+str(e)+')')
+            return None
+        sig = sig[:,0] if sig.ndim > 1 else sig
+        if fs != self.fs:
+            # resample expects an integer number of output samples
+            sig = signal.resample(sig, int(round(self.chunklen*self.fs)))
+        return torch.tensor(norm(sig)).float(), sample
+
+
+def norm(arr):
+    """ Standardises a signal to zero mean and unit variance (constant signals become all zeros) """
+    centered = arr - np.mean(arr)
+    std = np.std(arr)
+    # avoid dividing by zero, which would fill a constant chunk with NaN
+    return centered / std if std > 0 else centered
+
+
+def collate_fn(batch):
+    """ Collates a batch after dropping the samples that failed to load (None); returns None if none is left """
+    batch = list(filter(lambda x: x is not None, batch))
+    return torch.utils.data.dataloader.default_collate(batch) if len(batch) > 0 else None
+
+
+def main():
+    """ Parses the command line, runs the detector and writes the predictions to a csv file """
+    parser = argparse.ArgumentParser(description='Script to run the blue and fin whale vocalization detector')
+    parser.add_argument('output_fn', type=str, help='filename for the csv containing predictions')
+    parser.add_argument('files', type=str, nargs='+', help='list of files to process')
+    parser.add_argument('--folder', type=str, help='source folder for the files to process', default='')
+    parser.add_argument('--chunklen', type=float, help='size for chunking the sound files (result will give per chunk predictions)', default=20)
+    parser.add_argument('--batchsize', type=int, help='number of samples to process simultaneously (tune according to available memory space)', default=32)
+    args = parser.parse_args()
+    out = run(args.files, args.folder, args.chunklen, args.batchsize)
+    out.to_csv(args.output_fn, index=False)
+
+
+if __name__ == '__main__':
+    main()  # guarded so DataLoader worker processes importing this module do not re-run the script