Commit 704877d1 authored by Paul Best

first structure arrangement

parent ba7550d5
from torch import nn
import torch
import numpy as np

class Log1p(torch.nn.Module):
    """
    Applies log(1 + 10**a * x), with scale fixed or trainable.
    """
    def __init__(self, a=0, trainable=False):
        super(Log1p, self).__init__()
        if trainable:
            a = torch.nn.Parameter(torch.tensor(a, dtype=torch.get_default_dtype()))
        self.a = a
        self.trainable = trainable

    def forward(self, x):
        if self.trainable or self.a != 0:
            x = torch.log1p(10 ** self.a * x)
        return x

    def extra_repr(self):
        return 'trainable={}'.format(repr(self.trainable))
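
# Illustrative sketch (not part of the original file): a hypothetical helper
# showing Log1p on a spectrogram-shaped tensor, assuming the
# (batch, channel, freqs, time) layout used by PCENLayer below.
def _demo_log1p():
    compress = Log1p(a=2, trainable=False)
    spec = torch.rand(2, 1, 64, 100)  # dummy magnitude spectrogram
    out = compress(spec)              # applies log1p(10**2 * spec) elementwise
    return out.shape                  # shape is unchanged: (2, 1, 64, 100)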


class PCENLayer(torch.nn.Module):
    def __init__(self, num_bands,
                 s=0.025,
                 alpha=.8,
                 # @@ -13,11 +32,11 @@ class PCENLayer(nn.Module): (unchanged lines elided in the diff view)
                 eps=1e-6,
                 init_smoother_from_data=True):
        super(PCENLayer, self).__init__()
        self.log_s = torch.nn.Parameter(torch.log(torch.ones((1, 1, num_bands)) * s))
        self.log_alpha = torch.nn.Parameter(torch.log(torch.ones((1, 1, num_bands, 1)) * alpha))
        self.log_delta = torch.nn.Parameter(torch.log(torch.ones((1, 1, num_bands, 1)) * delta))
        self.log_r = torch.nn.Parameter(torch.log(torch.ones((1, 1, num_bands, 1)) * r))
        self.eps = torch.tensor(eps)
        self.init_smoother_from_data = init_smoother_from_data

    def forward(self, input):  # expected input (batch, channel, freqs, time)
        # @@ -27,11 +46,11 @@ class PCENLayer(nn.Module): (unchanged lines elided in the diff view)
        filtered = [init]
        for iframe in range(1, input.shape[-1]):
            filtered.append((1 - torch.exp(self.log_s)) * filtered[iframe - 1]
                            + torch.exp(self.log_s) * input[:, :, :, iframe])
        filtered = torch.stack(filtered).permute(1, 2, 3, 0)

        # stable reformulation due to Vincent Lostanlen; original formula was:
        # filtered = exp(-alpha * (log(self.eps) + log(1 + filtered / self.eps)))
        # return (input * filtered + delta)**r - delta**r
        alpha, delta, r = torch.exp(self.log_alpha), torch.exp(self.log_delta), torch.exp(self.log_r)
        return (input / (self.eps + filtered)**alpha + delta)**r - delta**r
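
# Illustrative sketch (not part of the original file): a hypothetical check that
# the stable reformulation above matches the commented-out original formula,
# since exp(-alpha*(log(eps) + log(1 + M/eps))) == (eps + M)**(-alpha).
def _check_pcen_reformulation():
    eps, alpha, delta, r = 1e-6, 0.8, 10.0, 0.25     # made-up values
    E = torch.rand(3, 4, dtype=torch.float64) + 0.1  # dummy spectrogram
    M = torch.rand(3, 4, dtype=torch.float64) + 0.1  # dummy smoothed spectrogram
    stable = (E / (eps + M)**alpha + delta)**r - delta**r
    gain = torch.exp(-alpha * (torch.log(torch.tensor(eps, dtype=torch.float64)) + torch.log(1 + M / eps)))
    original = (E * gain + delta)**r - delta**r
    return torch.allclose(stable, original)          # True up to float tolerance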

# @@ -80,7 +99,7 @@ def create_mel_filterbank(sample_rate, frame_len, num_bands, min_freq, max_freq, (unchanged lines elided in the diff view)
    return filterbank


class MelFilter(torch.nn.Module):
    def __init__(self, sample_rate, winsize, num_bands, min_freq, max_freq):
        super(MelFilter, self).__init__()
        melbank = create_mel_filterbank(sample_rate, winsize, num_bands,
        # @@ -112,7 +131,8 @@ class MelFilter(nn.Module): (unchanged lines elided in the diff view)
        self._buffers = buffers
        return result


class STFT(torch.nn.Module):
    def __init__(self, winsize, hopsize, complex=False):
        super(STFT, self).__init__()
        self.winsize = winsize
        # @@ -154,77 +174,3 @@ class STFT(nn.Module): (unchanged lines elided in the diff view)
        # restore original batchsize and channels in case we mashed them
        x = x.reshape((batchsize, channels, -1) + x.shape[2:])  # if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:])
        return x

HB_model = nn.Sequential(nn.Sequential(
        STFT(512, 64),
        MelFilter(11025, 512, 64, 100, 3000),
        PCENLayer(64),
    ),
    nn.Sequential(
        nn.Conv2d(1, 32, 3, bias=False),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 32, 3, bias=False),
        nn.BatchNorm2d(32),
        nn.MaxPool2d(3),
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 32, 3, bias=False),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 32, 3, bias=False),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 64, (16, 3), bias=False),
        nn.BatchNorm2d(64),
        nn.MaxPool2d((1, 3)),
        nn.LeakyReLU(0.01),
        nn.Dropout(p=.5),
        nn.Conv2d(64, 256, (1, 9), bias=False),  # for 80 bands
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.01),
        nn.Dropout(p=.5),
        nn.Conv2d(256, 64, 1, bias=False),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.01),
        nn.Dropout(p=.5),
        nn.Conv2d(64, 1, 1, bias=False)
    )
)

delphi_model = nn.Sequential(nn.Sequential(
        STFT(4096, 1024),
        MelFilter(96000, 4096, 128, 3000, 30000),
        PCENLayer(128),
    ),
    nn.Sequential(
        nn.Conv2d(1, 32, 3, bias=False),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 32, 3, bias=False),
        nn.BatchNorm2d(32),
        nn.MaxPool2d(3),
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 32, 3, bias=False),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 32, 3, bias=False),
        nn.BatchNorm2d(32),
        nn.LeakyReLU(0.01),
        nn.Conv2d(32, 64, (19, 3), bias=False),
        nn.BatchNorm2d(64),
        nn.MaxPool2d(3),
        nn.LeakyReLU(0.01),
        nn.Dropout(p=.5),
        nn.Conv2d(64, 256, (1, 9), bias=False),  # for 80 bands
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.01),
        nn.Dropout(p=.5),
        nn.Conv2d(256, 64, 1, bias=False),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.01),
        nn.Dropout(p=.5),
        nn.Conv2d(64, 1, 1, bias=False),
        nn.MaxPool2d((6, 1))
    )
)

from torch import nn
from frontend import STFT, MelFilter, PCENLayer, Log1p

get = {
    'megaptera': nn.Sequential(
        nn.Sequential(
            STFT(512, 64),
            MelFilter(11025, 512, 64, 100, 3000),
            PCENLayer(64)
        ),
        nn.Sequential(
            nn.Conv2d(1, 32, 3, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, 3, bias=False),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(3),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, 3, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, 3, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 64, (16, 3), bias=False),
            nn.BatchNorm2d(64),
            nn.MaxPool2d((1, 3)),
            nn.LeakyReLU(0.01),
            nn.Dropout(p=.5),
            nn.Conv2d(64, 256, (1, 9), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.01),
            nn.Dropout(p=.5),
            nn.Conv2d(256, 64, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.01),
            nn.Dropout(p=.5),
            nn.Conv2d(64, 1, 1, bias=False),
            nn.MaxPool2d((6, 1))
        )
    ),
    'delphinid': nn.Sequential(
        nn.Sequential(
            STFT(4096, 1024),
            MelFilter(96000, 4096, 128, 3000, 30000),
            PCENLayer(128)
        ),
        nn.Sequential(
            nn.Conv2d(1, 32, 3, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, 3, bias=False),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(3),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, 3, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 32, 3, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.01),
            nn.Conv2d(32, 64, (19, 3), bias=False),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(3),
            nn.LeakyReLU(0.01),
            nn.Dropout(p=.5),
            nn.Conv2d(64, 256, (1, 9), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.01),
            nn.Dropout(p=.5),
            nn.Conv2d(256, 64, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.01),
            nn.Dropout(p=.5),
            nn.Conv2d(64, 1, 1, bias=False),
            nn.MaxPool2d((6, 1))
        )
    )
}
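
# Illustrative sketch (not part of the original file): a hypothetical forward
# pass through one of the bundled models, assuming (as in the inference script)
# a batch of mono waveform excerpts sampled at the frontend's rate
# (11025 Hz for 'megaptera'), here with untrained weights and dummy audio.
if __name__ == '__main__':
    import torch
    model = get['megaptera']
    model.eval()
    with torch.no_grad():
        waveforms = torch.randn(2, 5 * 11025)  # (batch, samples): two 5 s excerpts
        scores = model(waveforms)              # detection scores, one map per excerpt
    print(scores.shape)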
import os
import torch
import models
from scipy import signal
import soundfile as sf
from torch.utils import data
import numpy as np
import pandas as pd
from tqdm import tqdm
import argparse

parser = argparse.ArgumentParser(description="Run this script to use a CNN for inference on a folder of audio files.")
parser.add_argument('audio_folder', type=str, help='Path of the folder with audio files to process')
parser.add_argument('specie', type=str, help='Target species to detect', choices=['megaptera', 'delphinid', 'orcinus', 'physeter', 'balaenoptera'])
parser.add_argument('pred_fn', type=str, help='Filename for the output table containing model predictions')
parser.add_argument('-lensample', type=float, help='Length of the signal excerpts to process (sec)', default=5)
parser.add_argument('-batchsize', type=int, help='Number of samples to process at a time', default=32)
parser.add_argument('-maxPool', type=bool, help='Whether to keep only the maximal prediction of a sample or the full sequence', default=True)
args = parser.parse_args()

meta_model = {
    'delphinid': {
        'stdc': 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc',
        'fs': 96000
    },
    'megaptera': {
        'stdc': 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc',
        'fs': 11025
    },
    'orcinus': '',
    'physeter': '',
    'balaenoptera': ''
}[args.specie]


def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return data.dataloader.default_collate(batch) if len(batch) > 0 else None
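
# Illustrative sketch (not part of the original script): collate_fn drops
# excerpts the Dataset failed to load (returned as None), so one bad file skips
# quietly instead of crashing the whole batch; the names below are hypothetical.
def _demo_collate_fn():
    good = (torch.zeros(8000), {'fn': 'a.wav', 'offset': 0.0})
    batch = collate_fn([good, None])                   # the None sample is filtered out
    return None if batch is None else batch[0].shape   # torch.Size([1, 8000])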

norm = lambda arr: (arr - np.mean(arr)) / np.std(arr)


def run(folder, stdcfile, model, fs, lensample, batch_size, maxPool):
    model.load_state_dict(torch.load(stdcfile))
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
    fns, offsets, preds = [], [], []
    loader = data.DataLoader(Dataset(folder, fs, lensample), batch_size=batch_size,
                             collate_fn=collate_fn, num_workers=8, prefetch_factor=4)
    with torch.no_grad():
        for x, meta in tqdm(loader):
            x = x.to(device)
            pred = model(x).cpu().detach().numpy()
            if maxPool:
                pred = pred.reshape(len(x), -1).max(axis=-1)
            else:
                pred = pred.reshape(len(x), -1)
            fns.extend(meta['fn'])
            offsets.extend(meta['offset'].numpy())
            preds.extend(pred)
    out.fn, out.offset, out.pred = fns, offsets, preds
    return out


class Dataset(data.Dataset):
    def __init__(self, folder, fs, lensample):
        super(Dataset, self).__init__()
        print('initializing dataset...')
        self.samples = []
        for fn in os.listdir(folder):
            try:
                duration = sf.info(folder + fn).duration
            except Exception:
                print(f'Skipping {fn} (unable to read)')
                continue
            for offset in np.arange(0, duration + .01 - lensample, lensample):
                self.samples.append({'fn': fn, 'offset': offset})
        self.fs, self.folder, self.lensample = fs, folder, lensample

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        fs = sf.info(self.folder + sample['fn']).samplerate
        try:
            sig, fs = sf.read(self.folder + sample['fn'], start=int(sample['offset'] * fs),
                              stop=int((sample['offset'] + self.lensample) * fs), always_2d=True)
        except Exception:
            print('Failed loading ' + sample['fn'])
            return None
        sig = sig[:, 0]
        if fs != self.fs:
            sig = signal.resample(sig, int(self.lensample * self.fs))
        sig = norm(sig)
        return torch.tensor(sig).float(), sample


preds = run(args.audio_folder,
            meta_model['stdc'],
            models.get[args.specie],
            meta_model['fs'],
            batch_size=args.batchsize,
            lensample=args.lensample,
            maxPool=args.maxPool)
preds.to_pickle(args.pred_fn)
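
# Example invocation (illustrative; the script filename is assumed):
#   python run_CNN.py /data/audio/ megaptera predictions.pkl -lensample 5 -batchsize 32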
from model import HB_model
from scipy import signal
import soundfile as sf
from torch import load, no_grad, tensor, device, cuda
from torch.utils import data
import numpy as np
import pandas as pd
from tqdm import tqdm
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('files', type=str, nargs='+')
parser.add_argument('-outfn', type=str, default='HB_preds.pkl')
args = parser.parse_args()
stdc = 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc'

def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return data.dataloader.default_collate(batch) if len(batch) > 0 else None


def run(files, stdcfile, model, folder, pool=False, lensample=5, batch_size=32):
    model.load_state_dict(load(stdcfile))
    model.eval()
    cuda0 = device('cuda' if cuda.is_available() else 'cpu')
    model.to(cuda0)
    out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
    fns, offsets, preds = [], [], []
    with no_grad():
        for x, meta in tqdm(data.DataLoader(Dataset(files, folder, lensample=lensample), batch_size=batch_size,
                                            collate_fn=collate_fn, num_workers=8, prefetch_factor=4)):
            x = x.to(cuda0, non_blocking=True)
            pred = model(x)
            temp = pd.DataFrame().from_dict(meta)
            fns.extend(meta['fn'])
            offsets.extend(meta['offset'].numpy())
            preds.extend(pred.reshape(len(x), -1).cpu().detach().numpy())
            # print(meta, temp, pred.reshape(len(x), -1).shape)
            # temp['pred'] = pred.reshape(len(x), -1).cpu().detach()
            # preds = preds.append(temp, ignore_index=True)
    out.fn, out.offset, out.pred = fns, offsets, preds
    # preds.pred = preds.pred.apply(np.array)
    return out


class Dataset(data.Dataset):
    def __init__(self, fns, folder, fe=11025, lenfile=120, lensample=50):  # lenfile and lensample in seconds
        super(Dataset, self).__init__()
        print('init dataset')
        self.samples = np.concatenate([[{'fn': fn, 'offset': offset}
                                        for offset in np.arange(0, sf.info(folder + fn).duration - lensample + 1, lensample)]
                                       for fn in fns if sf.info(folder + fn).duration > 10])
        self.lensample = lensample
        self.fe, self.folder = fe, folder

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        fs = sf.info(self.folder + sample['fn']).samplerate
        try:
            sig, fs = sf.read(self.folder + sample['fn'], start=max(0, int(sample['offset'] * fs)),
                              stop=int((sample['offset'] + self.lensample) * fs))
        except Exception:
            print('failed loading ' + sample['fn'])
            return None
        if sig.ndim > 1:
            sig = sig[:, 0]
        if len(sig) != fs * self.lensample:
            print('too short file ' + sample['fn'] + '\n' + str(sig.shape))
            return None
        if fs != self.fe:
            sig = signal.resample(sig, self.lensample * self.fe)
        sig = norm(sig)
        return tensor(sig).float(), sample


def norm(arr):
    return (arr - np.mean(arr)) / np.std(arr)


preds = run(args.files, stdc, HB_model, './', batch_size=3, lensample=50)
preds.to_pickle(args.outfn)
from model import delphi_model
from scipy import signal
import soundfile as sf
from torch import load, no_grad, tensor, device, cuda
from torch.utils import data
import numpy as np
import pandas as pd
from tqdm import tqdm
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('files', type=str, nargs='+')
parser.add_argument('-outfn', type=str, default='delphi_preds.pkl')
args = parser.parse_args()
stdc = 'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc'

def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return data.dataloader.default_collate(batch) if len(batch) > 0 else None


def run(files, stdcfile, model, folder, fe=96000, lensample=5, batch_size=32):
    model.load_state_dict(load(stdcfile))
    model.eval()
    cuda0 = device('cuda' if cuda.is_available() else 'cpu')
    model.to(cuda0)
    out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
    fns, offsets, preds = [], [], []
    with no_grad():
        for x, meta in tqdm(data.DataLoader(Dataset(files, folder, fe=fe, lensample=lensample), batch_size=batch_size,
                                            collate_fn=collate_fn, num_workers=8, prefetch_factor=4)):
            x = x.to(cuda0, non_blocking=True)
            pred = model(x)
            temp = pd.DataFrame().from_dict(meta)
            fns.extend(meta['fn'])
            offsets.extend(meta['offset'].numpy())
            preds.extend(pred.reshape(len(x), -1).cpu().detach().numpy())
            # print(meta, temp, pred.reshape(len(x), -1).shape)
            # temp['pred'] = pred.reshape(len(x), -1).cpu().detach()
            # preds = preds.append(temp, ignore_index=True)
    out.fn, out.offset, out.pred = fns, offsets, preds
    # preds.pred = preds.pred.apply(np.array)
    return out


class Dataset(data.Dataset):
    def __init__(self, fns, folder, fe=96000, lenfile=120, lensample=50):  # lenfile and lensample in seconds
        super(Dataset, self).__init__()
        print('init dataset')
        self.samples = np.concatenate([[{'fn': fn, 'offset': offset}
                                        for offset in np.arange(0, sf.info(folder + fn).duration - lensample + 1, lensample)]
                                       for fn in fns if sf.info(folder + fn).duration > 10])
        self.lensample = lensample
        self.fe, self.folder = fe, folder

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        fs = sf.info(self.folder + sample['fn']).samplerate
        try:
            sig, fs = sf.read(self.folder + sample['fn'], start=max(0, int(sample['offset'] * fs)),
                              stop=int((sample['offset'] + self.lensample) * fs))
        except Exception:
            print('failed loading ' + sample['fn'])
            return None
        if sig.ndim > 1:
            sig = sig[:, 0]
        if len(sig) != fs * self.lensample:
            print('too short file ' + sample['fn'] + '\n' + str(sig.shape))
            return None
        if fs != self.fe:
            sig = signal.resample(sig, self.lensample * self.fe)
        sig = norm(sig)
        return tensor(sig).float(), sample


def norm(arr):
    return (arr - np.mean(arr)) / np.std(arr)

preds = run(args.files, stdc, delphi_model, './', batch_size=3, lensample=50)
preds.to_pickle(args.outfn)