Commit 280b776f authored by Paul Best's avatar Paul Best
Browse files

functionnal

parent e01bb121
from torch import nn
import torch
import torchvision
resnet = torchvision.models.resnet50()
resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet.avgpool = nn.AdaptiveMaxPool2d(output_size=(1, 1))
resnet.fc = nn.Linear(2048, 7)
"""
Audio feature extraction routines.
Author: Jan Schluter
"""
class STFT(nn.Module):
def __init__(self, winsize, hopsize):
super(STFT, self).__init__()
self.winsize, self.hopsize = winsize, hopsize
self.register_buffer('window', torch.hann_window(winsize, periodic=False))
def state_dict(self, destination=None, prefix='', keep_vars=False):
result = super(STFT, self).state_dict(destination, prefix, keep_vars)
# remove all buffers; we use them as cached constants
for k in self._buffers:
del result[prefix + k]
return result
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
# ignore stored buffers for backwards compatibility
for k in self._buffers:
state_dict.pop(prefix + k, None)
# temporarily hide the buffers; we do not want to restore them
buffers = self._buffers
self._buffers = {}
result = super(STFT, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs)
self._buffers = buffers
return result
def forward(self, x):
x = x.unsqueeze(1)
# we want each channel to be treated separately, so we mash
# up the channels and batch size and split them up afterwards
batchsize, channels = x.shape[:2]
x = x.reshape((-1,) + x.shape[2:])
# we apply the STFT
x = torch.stft(x, self.winsize, self.hopsize, window=self.window, center=False, return_complex=False)
# we compute magnitudes, if requested
x = x.norm(p=2, dim=-1)
# restore original batchsize and channels in case we mashed them
x = x.reshape((batchsize, channels, -1) + x.shape[2:]) #if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:])
return x
class Log1p(nn.Module):
"""
Applies log(1 + 10**a * x), with scale fixed or trainable.
"""
def __init__(self, a=0, trainable=False):
super(Log1p, self).__init__()
if trainable:
a = nn.Parameter(torch.tensor(a, dtype=torch.get_default_dtype()))
self.a = a
self.trainable = trainable
def forward(self, x):
if self.trainable or self.a != 0:
x = torch.log1p(10 ** self.a * x)
return x
def extra_repr(self):
return 'trainable={}'.format(repr(self.trainable))
frontend_logstft = nn.Sequential(
STFT(256, 32),
Log1p(a=7, trainable=True)
)
from models import frontend_logstft, resnet
from scipy import signal, special
import soundfile as sf
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import argparse
parser = argparse.ArgumentParser(description='Script to run the blue and fin whale vocalization detector')
parser.add_argument('output_fn', type=str, help='filename for the csv containing predictions')
parser.add_argument('files', type=str, nargs='+', help='list of files to process')
parser.add_argument('--folder', type=str, help='source folder for the files to process', default='')
parser.add_argument('--chunklen', type=int, help='size for chunking the sound files (result will give per chunk predictions', default=20)
parser.add_argument('--batchsize', type=int, help='number of samples to process simultaneously (tune according to available memory space)', default=32)
args = parser.parse_args()
labels = ['Bp_20Plus', 'Bm_D', 'Bp-Downsweep', 'Bp_20Hz', 'Bm_Ant-A', 'Bm_Ant-Z', 'Bm_Ant-B']
def run(files, folder, chunklen, batch_size):
""" Runs the CNN model over a list of sound files, and returns a table with predictions
Parameters
----------
files : list
list of string with names of the sound files to process
folder : str
path to load the files from
chunklen : float
size in second used for chunking the sound files. Ending incomplete chunks won't be processed
batch_size : int
number of samples to process simultaneously (tune according to available memory space)
Returns
-------
out
Pandas DataFrame containing the predictions
"""
# prepare the model
model = torch.nn.Sequential(frontend_logstft, resnet)
model.load_state_dict(torch.load('resnet_weights.stdc'))
model.eval()
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
# run the model over each batch and store predictions
fns, offsets, preds = [], [], []
with torch.no_grad():
for x, meta in tqdm(torch.utils.data.DataLoader(Dataset(files, folder, chunklen=chunklen), batch_size=batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4)):
pred = model(x.to(device))
fns.extend(meta['fn'])
offsets.extend(meta['offset'].numpy())
preds.extend(pred.view(len(x), 7).cpu().detach().numpy())
# prepare the output table, and eventually return it
out = pd.DataFrame(columns=['fn', 'offset']+['pred_'+l for l in labels])
out.fn, out.offset = fns, offsets
for l, p in zip(labels, np.array(preds).T):
out['pred_'+l] = special.expit(p) # expit == sigmoid (normalise predictions between 0 and 1)
return out
class Dataset(torch.utils.data.Dataset):
"""
Class used to load each chunk of signal into batches
Attributes
----------
fns : list
filenames to process
folder : str
path to load the files from
chunklen : float
size in second used for chunking the sound files. Ending incomplete chunks won't be processed
"""
def __init__(self, fns, folder, chunklen):
super(Dataset, self)
print('dataset initialisation ...', end='')
self.samples = np.concatenate([[{'fn':fn, 'offset':offset} for offset in np.arange(0, sf.info(folder+fn).duration-chunklen+.01, chunklen)] for fn in fns])
print('...done!')
self.fs, self.folder, self.chunklen = 250, folder, chunklen
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
sample = self.samples[idx]
fs = sf.info(self.folder+sample['fn']).samplerate
try:
sig, fs = sf.read(self.folder+sample['fn'], start=int(sample['offset']*fs), stop=int((sample['offset']+self.chunklen)*fs))
except:
print('failed loading '+sample['fn'])
return None
sig = sig[:,0] if sig.ndim > 1 else sig
if fs != self.fs:
sig = signal.resample(sig, self.chunklen*self.fs)
return torch.tensor(norm(sig)).float(), sample
def norm(arr):
return (arr - np.mean(arr) ) / np.std(arr)
def collate_fn(batch):
batch = list(filter(lambda x: x is not None, batch))
return torch.utils.data.dataloader.default_collate(batch) if len(batch) > 0 else None
out = run(args.files, args.folder, args.chunklen, args.batchsize)
out.to_csv(args.output_fn, index=False)
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment