diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index 6e17b83cf15958fa546002632e2c94c421abeafd..b25717cd09b095360ea5bd010ea1645ec949ebae 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -5,8 +5,7 @@ import umap
 from tqdm import tqdm
 import argparse
 
-parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, \
-    description="Compute the auto-encoder embeddings of vocalizations once it was trained with train_AE.py")
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the auto-encoder embeddings of vocalizations once it has been trained with train_AE.py")
 parser.add_argument('modelname', type=str, help='Filename of the AE weights (.stdc)')
 parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
 parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
@@ -19,10 +18,10 @@ args = parser.parse_args()
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
-encoder = models.sparrow_encoder(args.bottleneck)
-decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4))
+encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
+decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
 model.load_state_dict(torch.load(args.modelname, map_location=device))
 
 df = pd.read_csv(args.detections)
 
diff --git a/new_specie/models.py b/new_specie/models.py
index 715e73db1b7ee43a11861b571c5a35f22d2dbe40..48d7cf84ec44fc40247731fe0c8b15cbbb54e559 100755
--- a/new_specie/models.py
+++ b/new_specie/models.py
@@ -13,16 +13,17 @@ for nm, mod in vgg16.named_modules():
 frontend = lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
     STFT(nfft, int((sampleDur*sr - nfft)/128)),
     MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
-    Log1p(7, trainable=False)
+    Log1p(7, trainable=False),
+    nn.InstanceNorm2d(1),
+    u.Croper2D(n_mel, 128)
 )
 
-sparrow_encoder = lambda nfeat : nn.Sequential(
+sparrow_encoder = lambda nfeat, shape : nn.Sequential(
     nn.Conv2d(1, 32, 3, stride=2, bias=False, padding=(1)),
     nn.BatchNorm2d(32),
     nn.LeakyReLU(0.01),
     nn.Conv2d(32, 64, 3, stride=2, bias=False, padding=1),
     nn.BatchNorm2d(64),
-    nn.MaxPool2d((1, 2)),
     nn.ReLU(True),
     nn.Conv2d(64, 128, 3, stride=2, bias=False, padding=1),
     nn.BatchNorm2d(128),
@@ -31,17 +32,15 @@ sparrow_encoder = lambda nfeat : nn.Sequential(
     nn.BatchNorm2d(256),
     nn.ReLU(True),
     nn.Conv2d(256, nfeat, (3, 5), stride=2, padding=(1, 2)),
-    nn.AdaptiveMaxPool2d((1,1)),
-    u.Reshape(nfeat)
+    u.Reshape(nfeat * shape[0] * shape[1])
 )
 
 sparrow_decoder = lambda nfeat, shape : nn.Sequential(
-    nn.Linear(nfeat, nfeat*shape[0]*shape[1]),
-    u.Reshape(nfeat, *shape),
+    u.Reshape(nfeat//(shape[0]*shape[1]), *shape),
     nn.ReLU(True),
     nn.Upsample(scale_factor=2),
-    nn.Conv2d(nfeat, 256, (3, 3), bias=False, padding=1),
+    nn.Conv2d(nfeat//(shape[0]*shape[1]), 256, (3, 3), bias=False, padding=1),
     nn.BatchNorm2d(256),
     nn.ReLU(True),
     nn.Conv2d(256, 256, (3, 3), bias=False, padding=1),
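Note on the new `sparrow_encoder(nfeat, shape)` / `sparrow_decoder(nfeat, shape)` contract: the frontend now always yields an (nMel, 128) spectrogram (via `Croper2D`), the encoder halves both axes with its stride-2 convolutions, so the last convolution outputs an (nMel//32, 4) grid whose flattened size must equal the bottleneck, hence `bottleneck // (nMel//32 * 4)` channels. A minimal shape-check sketch; only the shape-relevant convolutions are reproduced, and the 128-to-256-channel layer sits in unchanged context not shown in the hunk, so its stride of 2 is an assumption:

```python
import torch
from torch import nn

nMel, frames, bottleneck = 128, 128, 16           # default CLI values
nfeat = bottleneck // (nMel // 32 * 4)            # channels of the last encoder convolution

convs = nn.Sequential(                            # BatchNorm/activations omitted: they keep shapes
    nn.Conv2d(1, 32, 3, stride=2, padding=1),
    nn.Conv2d(32, 64, 3, stride=2, padding=1),
    nn.Conv2d(64, 128, 3, stride=2, padding=1),
    nn.Conv2d(128, 256, 3, stride=2, padding=1),  # assumed: this layer is in the unchanged context
    nn.Conv2d(256, nfeat, (3, 5), stride=2, padding=(1, 2)),
)
out = convs(torch.zeros(1, 1, nMel, frames))
print(out.shape)                                  # torch.Size([1, 1, 4, 4])
assert out.flatten(1).shape[1] == bottleneck      # matches u.Reshape(nfeat * shape[0] * shape[1])
```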
diff --git a/new_specie/sort_cluster.py b/new_specie/sort_cluster.py
index ba26d2f25cf632cf94d72028b52926107bcf3319..6ab1a705ffff57582999dfc6986d2c7f3105cd6d 100755
--- a/new_specie/sort_cluster.py
+++ b/new_specie/sort_cluster.py
@@ -1,3 +1,5 @@
+import soundfile as sf
+from scipy import signal
 import utils as u
 from tqdm import tqdm
 import matplotlib.pyplot as plt
@@ -5,6 +7,12 @@ import os
 import torch, numpy as np, pandas as pd
 import hdbscan
 import argparse
+import models
+try:
+    import sounddevice as sd
+    soundAvailable = True
+except:
+    soundAvailable = False
 
 parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, \
     description="""Interface to visualize projected vocalizations (UMAP reduced AE embeddings), and tune HDBSCAN parameters.\n
@@ -13,18 +21,23 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFo
     For insights on how to tune HDBSCAN parameters, read https://hdbscan.readthedocs.io/en/latest/parameter_selection.html""")
 parser.add_argument('encodings', type=str, help='.npy file containing umap projections and their associated index in the detection.pkl table (built using compute_embeddings.py)')
 parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
-parser.add_argument('audio_folder', type=str, help='Path to the folder with complete audio files')
+#parser.add_argument('audio_folder', type=str, help='Path to the folder with complete audio files')
+parser.add_argument("-audio_folder", type=str, default='', help="Folder from which to load sound files")
 parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
+parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
+parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
-parser.add_argument('-min_cluster_size', type=int, default=50, help='Used for HDBSCAN clustering.')
+parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.')
 parser.add_argument('-min_sample', type=int, default=5, help='Used for HDBSCAN clustering.')
-parser.add_argument('-eps', type=float, default=0.0, help='Used for HDBSCAN clustering.')
+parser.add_argument('-eps', type=float, default=0.05, help='Used for HDBSCAN clustering.')
 args = parser.parse_args()
 
-gpu = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 df = pd.read_csv(args.detections)
 encodings = np.load(args.encodings, allow_pickle=True).item()
 idxs, umap = encodings['idx'], encodings['umap']
+df.loc[idxs, 'umap_x'] = umap[:,0]
+df.loc[idxs, 'umap_y'] = umap[:,1]
+
 # Use HDBSCAN to cluster the embedings (min_cluster_size and min_samples parameters need to be tuned)
 df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
                                          min_samples=args.min_sample,
@@ -33,10 +46,53 @@ df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
                                          cluster_selection_method='leaf').fit_predict(umap)
 df.cluster = df.cluster.astype(int)
 
+frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+
 figscat = plt.figure(figsize=(10, 5))
 plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
-plt.scatter(umap[:,0], umap[:,1], s=3, alpha=.2, c=df.loc[idxs].cluster, cmap='tab20')
+plt.scatter(umap[:,0], umap[:,1], s=3, alpha=.8, c=df.loc[idxs].cluster, cmap='tab20')
 plt.tight_layout()
+axScat = figscat.axes[0]
+figSpec = plt.figure()
+plt.scatter(0, 0)
+axSpec = figSpec.axes[0]
+
+#print(df.cluster.value_counts())
+
+class temp():
+    def __init__(self):
+        self.row = ""
+    def onclick(self, event):
+        # find the closest point to the mouse
+        left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1]
+        rangex, rangey = right - left, top - bottom
+        closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin()
+        row = df.loc[closest]
+        # read waveform and compute spectrogram
+        info = sf.info(f'{args.audio_folder}/{row.filename}')
+        dur, fs = info.duration, info.samplerate
+        start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
+        sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
+        sig = sig[:,0]
+        if fs != args.SR:
+            sig = signal.resample(sig, int(len(sig)/fs*args.SR))
+        spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
+        axSpec.clear()
+        axSpec.imshow(spec, origin='lower', aspect='auto')
+        # Display and metadata management
+        axSpec.set_title(f'{closest}, cluster {row.cluster} ({(df.cluster==row.cluster).sum()} points)')
+        axScat.scatter(row.umap_x, row.umap_y, c='r')
+        axScat.set_xlim(left, right)
+        axScat.set_ylim(bottom, top)
+        figSpec.canvas.draw()
+        figscat.canvas.draw()
+        # Play the audio
+        if soundAvailable:
+            sd.play(sig*10, fs)
+
+mtemp = temp()
+cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
+
 plt.savefig('projection')
 plt.show()
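Since this script exists mainly to tune the clustering, a quick sweep over the HDBSCAN parameters exposed above can help justify defaults such as `min_cluster_size=10` and `eps=0.05` before inspecting clusters interactively. A minimal sketch; 'encodings.npy' is a placeholder for the file written by compute_embeddings.py:

```python
import numpy as np
import hdbscan

encodings = np.load('encodings.npy', allow_pickle=True).item()   # placeholder filename
proj = encodings['umap']
for mcs in (10, 20, 50):
    for eps in (0.0, 0.05, 0.1):
        labels = hdbscan.HDBSCAN(min_cluster_size=mcs, min_samples=5,
                                 cluster_selection_epsilon=eps,
                                 cluster_selection_method='leaf').fit_predict(proj)
        print(f'min_cluster_size={mcs:3d} eps={eps:.2f} -> '
              f'{labels.max() + 1} clusters, {(labels == -1).mean():.0%} noise')
```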
diff --git a/new_specie/train_AE.py b/new_specie/train_AE.py
index 9138d3fbc85a28cc2302bc6bb319020369b7f70c..15f63ebe99eb00e0d2a958b21476a1c60e6a673d 100755
--- a/new_specie/train_AE.py
+++ b/new_specie/train_AE.py
@@ -6,9 +6,10 @@ from tqdm import tqdm
 from torch.utils.tensorboard import SummaryWriter
 import argparse
 
-parser = argparse.ArgumentParser(description="""This script trains an auto-encoder to compress and depcompress vocalisation spectrograms.\n
-    Reconstruction quality can be monitored via tensorboard ($tensorboard --logdir=runs/ --bind_all)""")
-parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
+parser = argparse.ArgumentParser(description="""This script trains an auto-encoder to compress and decompress vocalisation spectrograms.
+    Reconstruction quality can be monitored via tensorboard ($tensorboard --logdir=runs/ --bind_all)""",
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. A column \'filename\' (path of the soundfile) and a column \'pos\' (center of the detection in seconds) are needed")
 parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
 parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
 parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
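For reference, the detections table consumed by these scripts is a plain CSV with the two required columns named in the help string above. A sketch with hypothetical file names:

```python
import pandas as pd

detections = pd.DataFrame({
    'filename': ['site1_day1.wav', 'site1_day1.wav', 'site2_day3.wav'],  # hypothetical sound files
    'pos': [12.4, 57.9, 3.2],                                            # detection centres, in seconds
})
detections.to_csv('detections.csv', index=False)
# then, for example: python train_AE.py detections.csv -audio_folder ./audio -nMel 128 -bottleneck 16
```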
@@ -17,47 +18,43 @@ parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signa
 parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
 args = parser.parse_args()
 
-df = pd.read_csv(args.detections)
-print(f'Training using {len(df)} vocalizations')
-
-nepoch = 100
-batch_size = 64 if torch.cuda.is_available() else 16
-nfeat = 16
-modelname = args.detections[:-4]+'_AE_weights.stdc'
+# init AE architecture
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-lr = 0.003
-wdL2 = 0.0
-writer = SummaryWriter('runs/'+modelname)
-vgg16 = models.vgg16.eval().to(device)
-
+assert args.nMel % 32 == 0 and args.nMel > 0, "nMel argument should be a multiple of 32"
+assert args.bottleneck % (args.nMel//32 * 4) == 0, "Bottleneck size must be a multiple of the last volume\'s size (nMel//32 * 4)"
 frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
-encoder = models.sparrow_encoder(args.bottleneck)
-decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4))
+encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
+decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
-
+# training / optimisation setup
+lr, wdL2, batch_size = 0.003, 0.0, 64 if torch.cuda.is_available() else 16
 optimizer = torch.optim.AdamW(model.parameters(), weight_decay=wdL2, lr=lr, betas=(0.8, 0.999))
 scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch : .99**epoch)
-loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
+vgg16 = models.vgg16.eval().to(device)
 loss_fun = torch.nn.MSELoss()
 
-print('Go for model '+modelname)
-step = 0
-for epoch in range(nepoch):
+# data loader
+df = pd.read_csv(args.detections)
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
+
+modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.stdc'
+step, writer = 0, SummaryWriter('runs/'+modelname)
+print(f'Go for model {modelname} with {len(df)} vocalizations')
+for epoch in range(100_000//len(loader)):
     for x, name in tqdm(loader, desc=str(epoch), leave=False):
         optimizer.zero_grad()
         label = frontend(x.to(device))
         x = encoder(label)
         pred = decoder(x)
-        predd = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:]))
-        labell = vgg16(label.expand(label.shape[0], 3, *label.shape[2:]))
+        vgg_pred = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:]))
+        vgg_label = vgg16(label.expand(label.shape[0], 3, *label.shape[2:]))
 
-        score = loss_fun(predd, labell)
+        score = loss_fun(vgg_pred, vgg_label)
         score.backward()
         optimizer.step()
         writer.add_scalar('loss', score.item(), step)
-
         if step%50==0 :
             images = [(e-e.min())/(e.max()-e.min()) for e in label[:8]]
             grid = make_grid(images)
@@ -68,7 +65,6 @@ for epoch in range(nepoch):
             writer.add_image('reconstruct', grid, step)
         step += 1
 
-    if epoch % 10 == 0:
-        scheduler.step()
+        if step % 500 == 0:
+            scheduler.step()
     torch.save(model.state_dict(), modelname)
-
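The renamed `vgg_pred`/`vgg_label` variables make the loss easier to read: reconstructions are compared in VGG16 feature space rather than pixel space. A standalone sketch of that idea, using torchvision's stock VGG16 (recent torchvision assumed) in place of `models.vgg16`, which the repo builds in models.py (see the named_modules loop above), so the feature depth chosen here is an assumption:

```python
import torch
from torchvision.models import vgg16

features = vgg16(weights='DEFAULT').features[:10].eval()   # early conv blocks only (assumption)
for p in features.parameters():
    p.requires_grad_(False)
loss_fun = torch.nn.MSELoss()

label = torch.rand(8, 1, 128, 128)     # target spectrograms (batch, 1, nMel, frames)
pred = torch.rand(8, 1, 128, 128)      # auto-encoder reconstructions
# replicate the single channel so the ImageNet-trained network accepts the input, as in the loop above
vgg_pred = features(pred.expand(-1, 3, -1, -1))
vgg_label = features(label.expand(-1, 3, -1, -1))
score = loss_fun(vgg_pred, vgg_label)
print(score.item())
```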
diff --git a/new_specie/utils.py b/new_specie/utils.py
index 8439833f8e2c7e2de8cb079b281e30bc1298a579..5c6644f6546275418b03cf9741321b4de758a0d6 100755
--- a/new_specie/utils.py
+++ b/new_specie/utils.py
@@ -1,14 +1,15 @@
 import soundfile as sf
-import torch
+from torch import nn, Tensor
+from torch.utils.data import Dataset, dataloader
 import numpy as np
 from scipy.signal import resample
 
 def collate_fn(batch):
     batch = list(filter(lambda x: x is not None, batch))
-    return torch.utils.data.dataloader.default_collate(batch)
+    return dataloader.default_collate(batch)
 
-class Dataset(torch.utils.data.Dataset):
+class Dataset(Dataset):
     def __init__(self, df, audiopath, sr, sampleDur):
         super(Dataset, self)
         self.audiopath, self.df, self.sr, self.sampleDur = audiopath, df, sr, sampleDur
@@ -25,30 +26,34 @@ class Dataset(torch.utils.data.Dataset):
             sig, fs = sf.read(self.audiopath+'/'+row.filename, start=start, stop=start + int(self.sampleDur*fs), always_2d=True)
             sig = sig[:,0]
         except:
-            print(f'failed to load sound from row {row.name} with filename {row.filename}')
+            print(f'Failed to load sound from row {row.name} with filename {row.filename}')
             return None
         if len(sig) < self.sampleDur * fs:
-            sig = np.pad(sig, int(self.sampleDur * fs - len(sig))//2+1, mode='reflect')[:int(self.sampleDur * fs)]
+            sig = np.concatenate([sig, np.zeros(int(self.sampleDur * fs) - len(sig))])
         if fs != self.sr:
             sig = resample(sig, int(len(sig)/fs*self.sr))
-        return torch.Tensor(norm(sig)).float(), row.name
+        return Tensor(norm(sig)).float(), row.name
 
 def norm(arr):
     return (arr - np.mean(arr) ) / np.std(arr)
-
-class Flatten(torch.nn.Module):
+class Flatten(nn.Module):
     def __init__(self):
         super(Flatten, self).__init__()
     def forward(self, x):
         return x.view(x.shape[0], -1)
-
-class Reshape(torch.nn.Module):
+class Reshape(nn.Module):
     def __init__(self, *shape):
         super(Reshape, self).__init__()
         self.shape = shape
-
     def forward(self, x): return x.view(x.shape[0], *self.shape)
+
+class Croper2D(nn.Module):
+    def __init__(self, *shape):
+        super(Croper2D, self).__init__()
+        self.shape = shape
+    def forward(self, x):
+        return x[:,:,:self.shape[0],:self.shape[1]]
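A quick check of the helper modules touched here, assuming it is run from the new_specie/ folder so that `utils` resolves to this file (note the Dataset now zero-pads short clips at the end instead of reflect-padding):

```python
import torch
from utils import Croper2D, Reshape

x = torch.randn(2, 1, 130, 140)                      # a spectrogram with a few extra bins/frames
print(Croper2D(128, 128)(x).shape)                   # torch.Size([2, 1, 128, 128])
print(Reshape(16)(torch.randn(2, 1, 4, 4)).shape)    # torch.Size([2, 16]), as used by sparrow_encoder
```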