Select Git revision
run_model.py 4.35 KiB
import argparse
import os

import numpy as np
import pandas as pd
import soundfile as sf
import torch
from scipy import signal, special
from tqdm import tqdm

from models import frontend_logstft, resnet
# Command-line interface. Arguments are parsed at import time: this module is meant
# to be executed as a script, not imported.
parser = argparse.ArgumentParser(description='Script to run the blue and fin whale vocalization detector')
parser.add_argument('output_fn', type=str, help='filename for the csv containing predictions')
parser.add_argument('files', type=str, nargs='+', help='list of files to process')
parser.add_argument('--folder', type=str, help='source folder for the files to process', default='')
parser.add_argument('--chunklen', type=int, help='size in seconds for chunking the sound files (the output has one row of predictions per chunk)', default=20)
parser.add_argument('--batchsize', type=int, help='number of samples to process simultaneously (tune according to available memory space)', default=32)
args = parser.parse_args()

# Names of the classes predicted by the model. run() pairs them with the model outputs
# by position, so this order must match the order of the model's output units.
# NOTE(review): presumably Bm = blue whale (B. musculus), Bp = fin whale (B. physalus) — confirm.
labels = ['Bp_20Plus', 'Bm_D', 'Bp-Downsweep', 'Bp_20Hz', 'Bm_Ant-A', 'Bm_Ant-Z', 'Bm_Ant-B']
def run(files, folder, chunklen, batch_size):
    """ Runs the CNN model over a list of sound files, and returns a table with predictions

    Parameters
    ----------
    files : list
        list of string with names of the sound files to process
    folder : str
        path to load the files from
    chunklen : float
        size in seconds used for chunking the sound files. Ending incomplete chunks won't be processed
    batch_size : int
        number of samples to process simultaneously (tune according to available memory space)

    Returns
    -------
    out
        Pandas DataFrame with one row per processed chunk: ``fn`` (file name), ``offset``
        (chunk start, in seconds), then one ``pred_<label>`` column per entry of ``labels``
        holding the sigmoid of the model output (a value between 0 and 1).
        Chunks whose file could not be loaded are absent from the table.
    """
    # Choose the device before loading the weights so the checkpoint can be mapped onto it;
    # without map_location, a checkpoint holding CUDA tensors cannot be loaded on a CPU-only host.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = torch.nn.Sequential(frontend_logstft, resnet)
    model.load_state_dict(torch.load('resnet_weights.stdc', map_location=device))
    model.eval()
    model.to(device)

    loader = torch.utils.data.DataLoader(
        Dataset(files, folder, chunklen=chunklen),
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=8,
        prefetch_factor=4,
    )

    # run the model over each batch and store predictions
    fns, offsets, preds = [], [], []
    with torch.no_grad():
        for batch in tqdm(loader):
            if batch is None:
                # collate_fn returns None when every item of the batch failed to load
                continue
            x, meta = batch
            pred = model(x.to(device))
            fns.extend(meta['fn'])
            offsets.extend(meta['offset'].numpy())
            preds.extend(pred.reshape(len(x), len(labels)).cpu().numpy())

    # expit == sigmoid (normalise predictions between 0 and 1); the reshape keeps a
    # (0, n_labels) shape when no chunk was processed, so the table is still well-formed
    probs = special.expit(np.array(preds).reshape(-1, len(labels)))
    out = pd.DataFrame({'fn': fns, 'offset': offsets})
    for i, label in enumerate(labels):
        out['pred_' + label] = probs[:, i]
    return out
class Dataset(torch.utils.data.Dataset):
    """
    Class used to load each chunk of signal into batches

    Each item is a ``(signal, sample)`` pair: ``signal`` is a float tensor of
    ``chunklen * fs`` samples at ``fs`` = 250 Hz, standardised with ``norm``, and
    ``sample`` is the ``{'fn': ..., 'offset': ...}`` dict describing the chunk (offset
    in seconds). An item whose file cannot be read is returned as ``None`` so that
    ``collate_fn`` can drop it.

    Attributes
    ----------
    samples : numpy.ndarray
        object array of ``{'fn', 'offset'}`` dicts, one per chunk
    fs : int
        sampling rate every chunk is resampled to
    folder : str
        path to load the files from (joined with each filename via os.path.join)
    chunklen : float
        size in seconds used for chunking the sound files. Ending incomplete chunks won't be processed
    """
    def __init__(self, fns, folder, chunklen):
        super().__init__()
        self.fs, self.folder, self.chunklen = 250, folder, chunklen
        print('dataset initialisation ...', end='', flush=True)
        samples = []
        for fn in fns:
            duration = sf.info(os.path.join(folder, fn)).duration
            # One chunk every `chunklen` seconds. The +.01 s slack also keeps a last chunk
            # that ends less than 10 ms past the end of the file (zero-padded in __getitem__).
            samples.extend({'fn': fn, 'offset': offset}
                           for offset in np.arange(0, duration - chunklen + .01, chunklen))
        self.samples = np.array(samples, dtype=object)
        print('...done!')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        path = os.path.join(self.folder, sample['fn'])
        # resample requires an integer number of samples (chunklen may be a float)
        target = int(round(self.chunklen * self.fs))
        try:
            fs = sf.info(path).samplerate
            start = int(sample['offset'] * fs)
            stop = int((sample['offset'] + self.chunklen) * fs)
            sig, fs = sf.read(path, start=start, stop=stop)
        except Exception:
            # best effort: skip unreadable files instead of aborting the whole run
            print('failed loading '+sample['fn'])
            return None
        sig = sig[:, 0] if sig.ndim > 1 else sig  # keep only the first channel
        if len(sig) < stop - start:
            # the read stops at the end of the file: pad so every chunk has the same
            # length, otherwise default_collate cannot stack the batch
            sig = np.pad(sig, (0, stop - start - len(sig)))
        if fs != self.fs or len(sig) != target:
            sig = signal.resample(sig, target)
        return torch.tensor(norm(sig)).float(), sample
def norm(arr):
    """Standardise a signal to zero mean and unit variance.

    Parameters
    ----------
    arr : array_like
        signal to standardise (Dataset calls it on a 1-D chunk)

    Returns
    -------
    numpy.ndarray
        ``(arr - mean) / std``. For a constant input (std == 0, e.g. a silent chunk)
        the centred signal, which is all zeros, is returned instead of dividing by zero;
        that division would produce NaNs that propagate into the model predictions.
    """
    arr = np.asarray(arr)
    centred = arr - np.mean(arr)
    std = np.std(arr)
    return centred / std if std > 0 else centred
def collate_fn(batch):
    """Collate dataset items into a batch, dropping those that failed to load.

    Items equal to ``None`` (files that could not be read) are discarded and the rest
    are merged with torch's default collation. Returns ``None`` when nothing is left,
    so the caller has to be prepared to skip such batches.
    """
    kept = [item for item in batch if item is not None]
    if not kept:
        return None
    return torch.utils.data.dataloader.default_collate(kept)
if __name__ == '__main__':
    # Guard needed because run() starts DataLoader worker processes. With the 'spawn'
    # start method, each worker re-imports the main module, and an unguarded call would
    # try to run the whole pipeline again inside every worker.
    out = run(args.files, args.folder, args.chunklen, args.batchsize)
    out.to_csv(args.output_fn, index=False)