Commit a28d2605 authored by Paul Best

interface for experiments with new species

parent f788a05d
import utils as u
import models
import numpy as np, pandas as pd, torch
import umap
from tqdm import tqdm
import argparse
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the AE projections of vocalizations once the auto-encoder has been trained.")
parser.add_argument('modelname', type=str, help='Filename of the AE weights (.stdc)')
parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
encoder = models.sparrow_encoder(args.bottleneck)
decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
model.load_state_dict(torch.load(args.modelname, map_location=device))
df = pd.read_csv(args.detections)
print('Computing AE projections...')
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8)
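# model[:2] below applies only the frontend and encoder of the Sequential model
# (the decoder is skipped), so each batch is mapped to its bottleneck embedding.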
with torch.no_grad():
    encodings, idxs = [], []
    for x, idx in tqdm(loader):
        encoding = model[:2](x.to(device))
        idxs.extend(idx)
        encodings.extend(encoding.cpu().detach())
idxs = np.array(idxs)
encodings = np.stack(encodings)
print('Computing UMAP projections...')
X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
np.save('encodings_'+args.modelname[:-4]+'npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
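# Optional sanity-check sketch: the dict saved above has to be read back with
# allow_pickle and unpacked with .item(), which is what the clustering interface below expects.
payload = np.load('encodings_'+args.modelname[:-4]+'npy', allow_pickle=True).item()
print({k: np.shape(v) for k, v in payload.items()})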
# Author : Jan Schlüter
from torch import nn
import torch
import numpy as np
def create_mel_filterbank(sample_rate, frame_len, num_bands, min_freq, max_freq,
                          norm=True, crop=False):
    """
    Creates a mel filterbank of `num_bands` triangular filters, with the first
    filter starting at `min_freq` and the last one stopping at `max_freq`.
    Returns the filterbank as a matrix suitable for a dot product against
    magnitude spectra created from samples at a sample rate of `sample_rate`
    with a window length of `frame_len` samples. If `norm`, will normalize
    each filter by its area. If `crop`, will exclude rows that exceed the
    maximum frequency and are therefore zero.
    """
    # mel-spaced peak frequencies
    min_mel = 1127 * np.log1p(min_freq / 7000.0)
    max_mel = 1127 * np.log1p(max_freq / 7000.0)
    peaks_mel = torch.linspace(min_mel, max_mel, num_bands + 2)
    peaks_hz = 7000 * (torch.expm1(peaks_mel / 1127))
    peaks_bin = peaks_hz * frame_len / sample_rate

    # create filterbank
    input_bins = (frame_len // 2) + 1
    if crop:
        input_bins = min(input_bins,
                         int(np.ceil(max_freq * frame_len /
                                     float(sample_rate))))
    x = torch.arange(input_bins, dtype=peaks_bin.dtype)[:, np.newaxis]
    l, c, r = peaks_bin[0:-2], peaks_bin[1:-1], peaks_bin[2:]
    # triangles are the minimum of two linear functions f(x) = a*x + b
    # left side of triangles: f(l) = 0, f(c) = 1 -> a=1/(c-l), b=-a*l
    tri_left = (x - l) / (c - l)
    # right side of triangles: f(c) = 1, f(r) = 0 -> a=1/(c-r), b=-a*r
    tri_right = (x - r) / (c - r)
    # combine by taking the minimum of the left and right sides
    tri = torch.min(tri_left, tri_right)
    # and clip to only keep positive values
    filterbank = torch.clamp(tri, min=0)

    # normalize by area
    if norm:
        filterbank /= filterbank.sum(0)
    return filterbank
class MelFilter(nn.Module):
    def __init__(self, sample_rate, winsize, num_bands, min_freq, max_freq):
        super(MelFilter, self).__init__()
        melbank = create_mel_filterbank(sample_rate, winsize, num_bands,
                                        min_freq, max_freq, crop=True)
        self.register_buffer('bank', melbank)

    def forward(self, x):
        x = x.transpose(-1, -2)  # put fft bands last
        x = x[..., :self.bank.shape[0]]  # remove unneeded fft bands
        x = x.matmul(self.bank)  # turn fft bands into mel bands
        x = x.transpose(-1, -2)  # put time last
        return x

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        result = super(MelFilter, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        # remove all buffers; we use them as cached constants
        for k in self._buffers:
            del result[prefix + k]
        return result

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # ignore stored buffers for backwards compatibility
        for k in self._buffers:
            state_dict.pop(prefix + k, None)
        # temporarily hide the buffers; we do not want to restore them
        buffers = self._buffers
        self._buffers = {}
        result = super(MelFilter, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs)
        self._buffers = buffers
        return result
class STFT(nn.Module):
    def __init__(self, winsize, hopsize, complex=False):
        super(STFT, self).__init__()
        self.winsize = winsize
        self.hopsize = hopsize
        self.register_buffer('window',
                             torch.hann_window(winsize, periodic=False))
        self.complex = complex

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        result = super(STFT, self).state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        # remove all buffers; we use them as cached constants
        for k in self._buffers:
            del result[prefix + k]
        return result

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # ignore stored buffers for backwards compatibility
        for k in self._buffers:
            state_dict.pop(prefix + k, None)
        # temporarily hide the buffers; we do not want to restore them
        buffers = self._buffers
        self._buffers = {}
        result = super(STFT, self)._load_from_state_dict(state_dict, prefix, *args, **kwargs)
        self._buffers = buffers
        return result

    def forward(self, x):
        x = x.unsqueeze(1)
        # we want each channel to be treated separately, so we mash
        # up the channels and batch size and split them up afterwards
        batchsize, channels = x.shape[:2]
        x = x.reshape((-1,) + x.shape[2:])
        # we apply the STFT
        x = torch.stft(x, self.winsize, self.hopsize, window=self.window,
                       center=False, return_complex=False)
        # we compute magnitudes, if requested
        if not self.complex:
            x = x.norm(p=2, dim=-1)
        # restore original batchsize and channels in case we mashed them
        x = x.reshape((batchsize, channels, -1) + x.shape[2:])  # if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:])
        return x
class TemporalBatchNorm(nn.Module):
    """
    Batch normalization of a (batch, channels, bands, time) tensor over all but
    the previous to last dimension (the frequency bands).
    """
    def __init__(self, num_bands):
        super(TemporalBatchNorm, self).__init__()
        self.bn = nn.BatchNorm1d(num_bands)

    def forward(self, x):
        shape = x.shape
        # squash channels into the batch dimension
        x = x.reshape((-1,) + x.shape[-2:])
        # pass through 1D batch normalization
        x = self.bn(x)
        # restore squashed dimensions
        return x.reshape(shape)
class Log1p(nn.Module):
    """
    Applies log(1 + 10**a * x), with scale fixed or trainable.
    """
    def __init__(self, a=0, trainable=False):
        super(Log1p, self).__init__()
        if trainable:
            a = nn.Parameter(torch.tensor(a, dtype=torch.get_default_dtype()))
        self.a = a
        self.trainable = trainable

    def forward(self, x):
        if self.trainable or self.a != 0:
            x = torch.log1p(10 ** self.a * x)
        return x

    def extra_repr(self):
        return 'trainable={}'.format(repr(self.trainable))
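# A quick shape check chaining the modules above into the same frontend that the models
# module builds below (parameter values are the defaults used by the scripts; purely illustrative):
if __name__ == '__main__':
    sr, nfft, dur, n_mel = 44100, 1024, 1.0, 128
    fe = nn.Sequential(
        STFT(nfft, int((dur * sr - nfft) / 128)),         # magnitude spectrogram
        MelFilter(sr, nfft, n_mel, sr // nfft, sr // 2),   # mel projection
        Log1p(7),                                          # log compression
    )
    x = torch.randn(2, int(sr * dur))                      # two one-second excerpts
    print(fe(x).shape)                                     # (batch, channel, mel bands, time frames)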
import torchvision.models as torchmodels
from torch import nn
import utils as u
from filterbank import STFT, MelFilter, Log1p
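# The first 13 VGG16 feature layers act as a fixed feature extractor for the perceptual
# reconstruction loss used by the training script below; its max-pooling layers are
# replaced with average pooling.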
vgg16 = torchmodels.vgg16(weights=torchmodels.VGG16_Weights.DEFAULT)
vgg16 = vgg16.features[:13]
for nm, mod in vgg16.named_modules():
    if isinstance(mod, nn.MaxPool2d):
        setattr(vgg16, nm, nn.AvgPool2d(2, 2))

frontend = lambda sr, nfft, sampleDur, n_mel: nn.Sequential(
    STFT(nfft, int((sampleDur*sr - nfft)/128)),
    MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
    Log1p(7, trainable=False)
)
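# With the defaults used by the scripts (SR=44100, NFFT=1024, sampleDur=1), the hop size
# above is int((44100 - 1024)/128) = 336, so each excerpt yields roughly 128 STFT frames;
# the lowest mel filter starts near one FFT bin (sr//nfft = 43 Hz) and the highest at sr/2.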
sparrow_encoder = lambda nfeat: nn.Sequential(
    nn.Conv2d(1, 32, 3, stride=2, bias=False, padding=1),
    nn.BatchNorm2d(32),
    nn.LeakyReLU(0.01),
    nn.Conv2d(32, 64, 3, stride=2, bias=False, padding=1),
    nn.BatchNorm2d(64),
    nn.MaxPool2d((1, 2)),
    nn.ReLU(True),
    nn.Conv2d(64, 128, 3, stride=2, bias=False, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    nn.Conv2d(128, 256, 3, stride=2, bias=False, padding=1),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    nn.Conv2d(256, nfeat, (3, 5), stride=2, padding=(1, 2)),
    nn.AdaptiveMaxPool2d((1, 1)),
    u.Reshape(nfeat)
)
sparrow_decoder = lambda nfeat, shape: nn.Sequential(
    nn.Linear(nfeat, nfeat*shape[0]*shape[1]),
    u.Reshape(nfeat, *shape),
    nn.ReLU(True),
    nn.Upsample(scale_factor=2),
    nn.Conv2d(nfeat, 256, (3, 3), bias=False, padding=1),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    nn.Conv2d(256, 256, (3, 3), bias=False, padding=1),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    nn.Upsample(scale_factor=2),
    nn.Conv2d(256, 128, (3, 3), bias=False, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    nn.Conv2d(128, 128, (3, 3), bias=False, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    nn.Upsample(scale_factor=2),
    nn.Conv2d(128, 64, (3, 3), bias=False, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    nn.Conv2d(64, 64, (3, 3), bias=False, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    nn.Upsample(scale_factor=2),
    nn.Conv2d(64, 32, (3, 3), bias=False, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(True),
    nn.Conv2d(32, 32, (3, 3), bias=False, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(True),
    nn.Upsample(scale_factor=2),
    nn.Conv2d(32, 32, (3, 3), bias=False, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(True),
    nn.Conv2d(32, 1, (3, 3), bias=False, padding=1),
    nn.ReLU(True)
)
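# A minimal shape sanity check for the auto-encoder parts above (bottleneck size and
# spectrogram shape follow the defaults used elsewhere in this commit; purely illustrative):
if __name__ == '__main__':
    import torch
    enc, dec = sparrow_encoder(16), sparrow_decoder(16, (4, 4))  # (4, 4) matches 128 mel bands
    spec = torch.randn(2, 1, 128, 128)   # (batch, channel, mel bands, time frames)
    z = enc(spec)
    print(z.shape)                       # -> torch.Size([2, 16])
    print(dec(z).shape)                  # -> torch.Size([2, 1, 128, 128])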
import utils as u
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import torch, numpy as np, pandas as pd
import hdbscan
import argparse
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="""Interface to visualize projected vocalizations (UMAP-reduced AE embeddings), tune HDBSCAN parameters, and browse clusters by clicking on projected points.\n
Once satisfactory parameters are found, the clusters can be plotted into .png folders by typing y after closing the projection plot.\n
For insights on how to tune HDBSCAN parameters, read https://hdbscan.readthedocs.io/en/latest/parameter_selection.html.\n
To enable sound playing when browsing points, make sure the sounddevice package is installed.""")
parser.add_argument('encodings', type=str, help='.npy file containing UMAP projections and their associated indices in the detections .csv table (built using compute_embeddings.py)')
parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
parser.add_argument('audio_folder', type=str, help='Path to the folder with complete audio files')
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
parser.add_argument('-min_cluster_size', type=int, default=50, help='Used for HDBSCAN clustering.')
parser.add_argument('-min_sample', type=int, default=5, help='Used for HDBSCAN clustering.')
parser.add_argument('-eps', type=float, default=0.0, help='Used for HDBSCAN clustering.')
args = parser.parse_args()
gpu = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
df = pd.read_csv(args.detections)
encodings = np.load(args.encodings, allow_pickle=True).item()
idxs, umap = encodings['idx'], encodings['umap']
# Use HDBSCAN to cluster the embeddings (min_cluster_size and min_samples parameters need to be tuned)
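# (larger min_cluster_size discards small groups; larger min_samples makes the clustering
# more conservative and labels more points as noise, i.e. cluster -1; see the HDBSCAN
# parameter selection guide linked in the description above)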
df.loc[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
                                          min_samples=args.min_sample,
                                          core_dist_n_jobs=-1,
                                          cluster_selection_epsilon=args.eps,
                                          cluster_selection_method='leaf').fit_predict(umap)
df.cluster = df.cluster.astype(int)
figscat = plt.figure(figsize=(10, 5))
plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
plt.scatter(umap[:,0], umap[:,1], s=3, alpha=.2, c=df.loc[idxs].cluster, cmap='tab20')
plt.tight_layout()
plt.savefig('projection')
plt.show()
if input('\nType y to print cluster pngs.\n/!\ the cluster_pngs folder will be reset, backup if needed /!\ ') != 'y':
    exit()

os.system('rm -R cluster_pngs/*')

for c, grp in df.groupby('cluster'):
    if c == -1 or len(grp) > 10_000:
        continue
    os.system('mkdir -p cluster_pngs/'+str(c))
    loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8)
    with torch.no_grad():
        for x, idx in tqdm(loader, leave=False, desc=str(c)):
            plt.specgram(x.squeeze().numpy())
            plt.tight_layout()
            plt.savefig(f'cluster_pngs/{c}/{idx.squeeze().item()}')
            plt.close()
from torchvision.utils import make_grid
import torch
import pandas as pd
import utils as u, models
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
args = parser.parse_args()
df = pd.read_csv(args.detections)
print(f'Training using {len(df)} vocalizations')
nepoch = 100
batch_size = 64 if torch.cuda.is_available() else 16
nfeat = 16
modelname = args.detections[:-4]+'_AE_weights.stdc'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lr = 0.003
wdL2 = 0.0
writer = SummaryWriter('runs/'+modelname)
vgg16 = models.vgg16.eval().to(device)
frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
encoder = models.sparrow_encoder(args.bottleneck)
decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=wdL2, lr=lr, betas=(0.8, 0.999))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch : .99**epoch)
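# LambdaLR multiplies the base learning rate by .99**(number of scheduler.step() calls);
# scheduler.step() is only called every 10 epochs below, i.e. roughly a 1% decay per 10 epochs.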
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
loss_fun = torch.nn.MSELoss()
print('Go for model '+modelname)
step = 0
for epoch in range(nepoch):
    for x, name in tqdm(loader, desc=str(epoch), leave=False):
        optimizer.zero_grad()
        label = frontend(x.to(device))
        x = encoder(label)
        pred = decoder(x)
        # perceptual loss: compare VGG16 feature maps of the reconstruction and the target
        # (spectrograms are single-channel, so they are expanded to the 3 channels VGG16 expects)
        predd = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:]))
        labell = vgg16(label.expand(label.shape[0], 3, *label.shape[2:]))
        score = loss_fun(predd, labell)
        score.backward()
        optimizer.step()
        writer.add_scalar('loss', score.item(), step)

        if step % 50 == 0:
            images = [(e - e.min())/(e.max() - e.min()) for e in label[:8]]
            grid = make_grid(images)
            writer.add_image('target', grid, step)
            writer.add_embedding(x.detach(), global_step=step, label_img=label)
            images = [(e - e.min())/(e.max() - e.min()) for e in pred[:8]]
            grid = make_grid(images)
            writer.add_image('reconstruct', grid, step)
        step += 1

    if epoch % 10 == 0:
        scheduler.step()
    torch.save(model.state_dict(), modelname)
import soundfile as sf
import torch
import numpy as np
from scipy.signal import resample
def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)
class Dataset(torch.utils.data.Dataset):
    def __init__(self, df, audiopath, sr, sampleDur):
        super(Dataset, self).__init__()
        self.audiopath, self.df, self.sr, self.sampleDur = audiopath, df, sr, sampleDur

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        try:
            info = sf.info(self.audiopath+'/'+row.filename)
            dur, fs = info.duration, info.samplerate
            start = int(np.clip(row.pos - self.sampleDur/2, 0, max(0, dur - self.sampleDur)) * fs)
            sig, fs = sf.read(self.audiopath+'/'+row.filename, start=start, stop=start + int(self.sampleDur*fs), always_2d=True)
            sig = sig[:, 0]
        except Exception:
            print(f'failed to load sound from row {row.name} with filename {row.filename}')
            return None
        if len(sig) < self.sampleDur * fs:
            sig = np.pad(sig, int(self.sampleDur * fs - len(sig))//2+1, mode='reflect')[:int(self.sampleDur * fs)]
        if fs != self.sr:
            sig = resample(sig, int(len(sig)/fs*self.sr))
        return torch.Tensor(norm(sig)).float(), row.name
def norm(arr):
    return (arr - np.mean(arr)) / np.std(arr)

class Flatten(torch.nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.shape[0], -1)

class Reshape(torch.nn.Module):
    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(x.shape[0], *self.shape)
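# Usage sketch for the Dataset above, mirroring how the scripts build their loaders
# (the detection table and audio folder are placeholders):
if __name__ == '__main__':
    import pandas as pd
    detections = pd.DataFrame({'filename': ['call1.wav'], 'pos': [2.5]})
    dataset = Dataset(detections, './audio', sr=44100, sampleDur=1.0)
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=collate_fn)
    # collate_fn drops samples whose audio failed to load (__getitem__ returns None for those)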