diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index 93f344d5b9c1bcacd522451e75083b5c24e0c568..2b18520d395e79ce5e4ab74b1b6c6b1e024f309a 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -3,7 +3,7 @@ import models
 import numpy as np, pandas as pd, torch
 import umap
 from tqdm import tqdm
-import argparse
+import argparse, os
 torch.multiprocessing.set_sharing_strategy('file_system')
 
 parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the auto-encoder embeddings of vocalizations once it was trained with train_AE.py")
@@ -24,6 +24,7 @@ frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
+model.load_state_dict(torch.load(args.modelname))
 
 df = pd.read_csv(args.detections)
 
@@ -40,4 +41,6 @@ encodings = np.stack(encodings)
 print('Computing UMAP projections...')
 X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
 
-np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split(".")[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
+out_fn = f'encodings_{os.path.basename(args.detections).rsplit(".",1)[0]}.npy'
+print(f'Saving into {out_fn}')
+np.save(out_fn, {'encodings':encodings, 'idx':idxs, 'umap':X})
diff --git a/new_specie/sort_cluster.py b/new_specie/sort_cluster.py
index 56e5dc7705bf92248f9a53b9296fdd392b9646e6..aea8ada1b4876ba268ff6d16bfd83d26b434a5e6 100755
--- a/new_specie/sort_cluster.py
+++ b/new_specie/sort_cluster.py
@@ -5,8 +5,7 @@ from tqdm import tqdm
 import matplotlib.pyplot as plt
 import os
 import torch, numpy as np, pandas as pd
-from filterbank import STFT, MelFilter, MedFilt, Log1p
-import hdbscan
+import hdbscan, umap
 import argparse
 import models
 try:
@@ -22,6 +21,7 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFo
 For insights on how to tune HDBSCAN parameters, read https://hdbscan.readthedocs.io/en/latest/parameter_selection.html""")
 parser.add_argument('encodings', type=str, help='.npy file containing umap projections and their associated index in the detection.pkl table (built using compute_embeddings.py)')
 parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
+parser.add_argument('-umap_ndim', type=int, help="number of dimension for the UMAP compression", default=2)
 parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
 parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
 parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
@@ -29,80 +29,83 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
 parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.')
 parser.add_argument('-channel', type=int, default=0)
-parser.add_argument('-min_sample', type=int, default=5, help='Used for HDBSCAN clustering.')
-parser.add_argument('-eps', type=float, default=0.05, help='Used for HDBSCAN clustering.')
+parser.add_argument('-min_sample', type=int, default=3, help='Used for HDBSCAN clustering.')
+parser.add_argument('-eps', type=float, default=0.01, help='Used for HDBSCAN clustering.')
 args = parser.parse_args()
 
 df = pd.read_csv(args.detections)
 encodings = np.load(args.encodings, allow_pickle=True).item()
-idxs, umap = encodings['idx'], encodings['umap']
-df.loc[idxs, 'umap_x'] = umap[:,0]
-df.loc[idxs, 'umap_y'] = umap[:,1]
+idxs, umap_, embeddings = encodings['idx'], encodings['umap'], encodings['encodings']
+frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
 
-# Use HDBSCAN to cluster the embedings (min_cluster_size and min_samples parameters need to be tuned)
-df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
-                                         min_samples=args.min_sample,
-                                         core_dist_n_jobs=-1,
-                                         cluster_selection_epsilon=args.eps,
-                                         cluster_selection_method='leaf').fit_predict(umap)
-df.cluster = df.cluster.astype(int)
+if args.umap_ndim == 2:
+    df.loc[idxs, 'umap_x'] = umap_[:,0]
+    df.loc[idxs, 'umap_y'] = umap_[:,1]
 
-fs = 44100
-frontend = torch.nn.Sequential(
-    STFT(2048, 256),
-    MelFilter(fs, 2048, 96, 500, 4000),
-    Log1p(4),
-    MedFilt()
-)
+    # Use HDBSCAN to cluster the embedings (min_cluster_size and min_samples parameters need to be tuned)
+    df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
+                                             min_samples=args.min_sample,
+                                             core_dist_n_jobs=-1,
+                                             cluster_selection_epsilon=args.eps,
+                                             cluster_selection_method='leaf').fit_predict(umap_)
+    figscat = plt.figure(figsize=(10, 5))
+    plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
+    plt.scatter(umap_[:,0], umap_[:,1], s=3, alpha=.8, c=df.loc[idxs].cluster, cmap='tab20')
+    plt.tight_layout()
+    axScat = figscat.axes[0]
+    plt.savefig('projection')
 
-figscat = plt.figure(figsize=(10, 5))
-plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
-plt.scatter(umap[:,0], umap[:,1], s=3, alpha=.8, c=df.loc[idxs].cluster, cmap='tab20')
-plt.tight_layout()
-axScat = figscat.axes[0]
-figSpec = plt.figure()
-plt.scatter(0, 0)
-axSpec = figSpec.axes[0]
+    figSpec = plt.figure()
+    plt.scatter(0, 0)
+    axSpec = figSpec.axes[0]
 
-#print(df.cluster.value_counts())
+    #print(df.cluster.value_counts())
 
 
-class temp():
-    def __init__(self):
-        self.row = ""
-    def onclick(self, event):
-        # find the closest point to the mouse
-        left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1]
-        rangex, rangey = right - left, top - bottom
-        closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin()
-        row = df.loc[closest]
-        # read waveform and compute spectrogram
-        info = sf.info(f'{args.audio_folder}/{row.filename}')
-        dur, fs = info.duration, info.samplerate
-        start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
-        sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
-        sig = sig[:, args.channel]
-        if fs != args.SR:
-            sig = signal.resample(sig, int(len(sig)/fs*args.SR))
-        spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
-        axSpec.clear()
-        axSpec.imshow(spec, origin='lower', aspect='auto')
-        # Display and metadata management
-        axSpec.set_title(f'{closest}, cluster {row.cluster} ({(df.cluster==row.cluster).sum()} points)')
-        axScat.scatter(row.umap_x, row.umap_y, c='r')
-        axScat.set_xlim(left, right)
-        axScat.set_ylim(bottom, top)
-        figSpec.canvas.draw()
-        figscat.canvas.draw()
-        # Play the audio
-        if soundAvailable:
-            sd.play(sig, fs)
+    class temp():
+        def __init__(self):
+            self.row = ""
+        def onclick(self, event):
+            # find the closest point to the mouse
+            left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1]
+            rangex, rangey = right - left, top - bottom
+            closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin()
+            row = df.loc[closest]
+            # read waveform and compute spectrogram
+            info = sf.info(f'{args.audio_folder}/{row.filename}')
+            dur, fs = info.duration, info.samplerate
+            start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
+            sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
+            sig = sig[:, args.channel]
+            if fs != args.SR:
+                sig = signal.resample(sig, int(len(sig)/fs*args.SR))
+            spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
+            axSpec.clear()
+            axSpec.imshow(spec, origin='lower', aspect='auto')
+            # Display and metadata management
+            axSpec.set_title(f'{closest}, cluster {row.cluster:.0f} ({(df.cluster==row.cluster).sum()} points)')
+            axScat.scatter(row.umap_x, row.umap_y, c='r')
+            axScat.set_xlim(left, right)
+            axScat.set_ylim(bottom, top)
+            figSpec.canvas.draw()
+            figscat.canvas.draw()
+            # Play the audio
+            if soundAvailable:
+                sd.play(sig, fs)
 
-mtemp = temp()
-cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
+    mtemp = temp()
+    cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
+
+    plt.show()
+else :
+    umap_ = umap.UMAP(n_jobs=-1, n_components=args.umap_ndim).fit_transform(embeddings)
+    df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
+                                             min_samples=args.min_sample,
+                                             core_dist_n_jobs=-1,
+                                             cluster_selection_epsilon=args.eps,
+                                             cluster_selection_method='leaf').fit_predict(umap_)
+    print(df.cluster.value_counts().describe())
 
-plt.savefig('projection')
-plt.show()
 
 if input('\nType y to print cluster pngs.\n/!\ the cluster_pngs folder will be reset, backup if needed /!\ ') != 'y':
     exit()
@@ -112,11 +115,11 @@ os.system('rm -R cluster_pngs/*')
 for c, grp in df.groupby('cluster'):
     if c == -1 or len(grp) > 10_000:
         continue
-    os.system('mkdir -p cluster_pngs/'+str(c))
-    loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8)
+    os.system(f'mkdir -p cluster_pngs/{c:.0f}')
+    loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8, collate_fn=u.collate_fn)
    with torch.no_grad():
-        for x, idx in tqdm(loader, leave=False, desc=str(c)):
+        for x, idx in tqdm(loader, leave=False, desc=str(int(c))):
             plt.imshow(frontend(x).squeeze().numpy(), origin='lower', aspect='auto')
             plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
-            plt.savefig(f'cluster_pngs/{c}/{idx.squeeze().item()}')
+            plt.savefig(f'cluster_pngs/{c:.0f}/{idx.squeeze().item()}')
             plt.close()