Skip to content
Snippets Groups Projects
Commit 8f733869 authored by Paul Best
Browse files

fixes

parent 197e07b2
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ import models ...@@ -3,7 +3,7 @@ import models
import numpy as np, pandas as pd, torch import numpy as np, pandas as pd, torch
import umap import umap
from tqdm import tqdm from tqdm import tqdm
import argparse import argparse, os
torch.multiprocessing.set_sharing_strategy('file_system') torch.multiprocessing.set_sharing_strategy('file_system')
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the auto-encoder embeddings of vocalizations once it was trained with train_AE.py") parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the auto-encoder embeddings of vocalizations once it was trained with train_AE.py")
...@@ -24,6 +24,7 @@ frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel ...@@ -24,6 +24,7 @@ frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4)) encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4)) decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(device) model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
model.load_state_dict(torch.load(args.modelname))
df = pd.read_csv(args.detections) df = pd.read_csv(args.detections)
...@@ -40,4 +41,6 @@ encodings = np.stack(encodings) ...@@ -40,4 +41,6 @@ encodings = np.stack(encodings)
print('Computing UMAP projections...') print('Computing UMAP projections...')
X = umap.UMAP(n_jobs=-1).fit_transform(encodings) X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split(".")[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X}) out_fn = f'encodings_{os.path.basename(args.detections).rsplit(".",1)[0]}.npy'
print(f'Saving into {out_fn}')
np.save(out_fn, {'encodings':encodings, 'idx':idxs, 'umap':X})
...@@ -5,8 +5,7 @@ from tqdm import tqdm ...@@ -5,8 +5,7 @@ from tqdm import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import os import os
import torch, numpy as np, pandas as pd import torch, numpy as np, pandas as pd
from filterbank import STFT, MelFilter, MedFilt, Log1p import hdbscan, umap
import hdbscan
import argparse import argparse
import models import models
try: try:
...@@ -22,6 +21,7 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFo ...@@ -22,6 +21,7 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFo
For insights on how to tune HDBSCAN parameters, read https://hdbscan.readthedocs.io/en/latest/parameter_selection.html""") For insights on how to tune HDBSCAN parameters, read https://hdbscan.readthedocs.io/en/latest/parameter_selection.html""")
parser.add_argument('encodings', type=str, help='.npy file containing umap projections and their associated index in the detection.pkl table (built using compute_embeddings.py)') parser.add_argument('encodings', type=str, help='.npy file containing umap projections and their associated index in the detection.pkl table (built using compute_embeddings.py)')
parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed") parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
parser.add_argument('-umap_ndim', type=int, help="number of dimension for the UMAP compression", default=2)
parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files") parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation") parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)") parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
...@@ -29,80 +29,83 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec ...@@ -29,80 +29,83 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded") parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.') parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.')
parser.add_argument('-channel', type=int, default=0) parser.add_argument('-channel', type=int, default=0)
parser.add_argument('-min_sample', type=int, default=5, help='Used for HDBSCAN clustering.') parser.add_argument('-min_sample', type=int, default=3, help='Used for HDBSCAN clustering.')
parser.add_argument('-eps', type=float, default=0.05, help='Used for HDBSCAN clustering.') parser.add_argument('-eps', type=float, default=0.01, help='Used for HDBSCAN clustering.')
args = parser.parse_args() args = parser.parse_args()
df = pd.read_csv(args.detections) df = pd.read_csv(args.detections)
encodings = np.load(args.encodings, allow_pickle=True).item() encodings = np.load(args.encodings, allow_pickle=True).item()
idxs, umap = encodings['idx'], encodings['umap'] idxs, umap_, embeddings = encodings['idx'], encodings['umap'], encodings['encodings']
df.loc[idxs, 'umap_x'] = umap[:,0] frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
df.loc[idxs, 'umap_y'] = umap[:,1]
# Use HDBSCAN to cluster the embedings (min_cluster_size and min_samples parameters need to be tuned) if args.umap_ndim == 2:
df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size, df.loc[idxs, 'umap_x'] = umap_[:,0]
min_samples=args.min_sample, df.loc[idxs, 'umap_y'] = umap_[:,1]
core_dist_n_jobs=-1,
cluster_selection_epsilon=args.eps,
cluster_selection_method='leaf').fit_predict(umap)
df.cluster = df.cluster.astype(int)
fs = 44100 # Use HDBSCAN to cluster the embedings (min_cluster_size and min_samples parameters need to be tuned)
frontend = torch.nn.Sequential( df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
STFT(2048, 256), min_samples=args.min_sample,
MelFilter(fs, 2048, 96, 500, 4000), core_dist_n_jobs=-1,
Log1p(4), cluster_selection_epsilon=args.eps,
MedFilt() cluster_selection_method='leaf').fit_predict(umap_)
)
figscat = plt.figure(figsize=(10, 5))
plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
plt.scatter(umap_[:,0], umap_[:,1], s=3, alpha=.8, c=df.loc[idxs].cluster, cmap='tab20')
plt.tight_layout()
axScat = figscat.axes[0]
plt.savefig('projection')
figscat = plt.figure(figsize=(10, 5)) figSpec = plt.figure()
plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}') plt.scatter(0, 0)
plt.scatter(umap[:,0], umap[:,1], s=3, alpha=.8, c=df.loc[idxs].cluster, cmap='tab20') axSpec = figSpec.axes[0]
plt.tight_layout()
axScat = figscat.axes[0]
figSpec = plt.figure()
plt.scatter(0, 0)
axSpec = figSpec.axes[0]
#print(df.cluster.value_counts()) #print(df.cluster.value_counts())
class temp(): class temp():
def __init__(self): def __init__(self):
self.row = "" self.row = ""
def onclick(self, event): def onclick(self, event):
# find the closest point to the mouse # find the closest point to the mouse
left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1] left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1]
rangex, rangey = right - left, top - bottom rangex, rangey = right - left, top - bottom
closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin() closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin()
row = df.loc[closest] row = df.loc[closest]
# read waveform and compute spectrogram # read waveform and compute spectrogram
info = sf.info(f'{args.audio_folder}/{row.filename}') info = sf.info(f'{args.audio_folder}/{row.filename}')
dur, fs = info.duration, info.samplerate dur, fs = info.duration, info.samplerate
start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs) start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True) sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
sig = sig[:, args.channel] sig = sig[:, args.channel]
if fs != args.SR: if fs != args.SR:
sig = signal.resample(sig, int(len(sig)/fs*args.SR)) sig = signal.resample(sig, int(len(sig)/fs*args.SR))
spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze() spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
axSpec.clear() axSpec.clear()
axSpec.imshow(spec, origin='lower', aspect='auto') axSpec.imshow(spec, origin='lower', aspect='auto')
# Display and metadata management # Display and metadata management
axSpec.set_title(f'{closest}, cluster {row.cluster} ({(df.cluster==row.cluster).sum()} points)') axSpec.set_title(f'{closest}, cluster {row.cluster:.0f} ({(df.cluster==row.cluster).sum()} points)')
axScat.scatter(row.umap_x, row.umap_y, c='r') axScat.scatter(row.umap_x, row.umap_y, c='r')
axScat.set_xlim(left, right) axScat.set_xlim(left, right)
axScat.set_ylim(bottom, top) axScat.set_ylim(bottom, top)
figSpec.canvas.draw() figSpec.canvas.draw()
figscat.canvas.draw() figscat.canvas.draw()
# Play the audio # Play the audio
if soundAvailable: if soundAvailable:
sd.play(sig, fs) sd.play(sig, fs)
mtemp = temp() mtemp = temp()
cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick) cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
plt.show()
else :
umap_ = umap.UMAP(n_jobs=-1, n_components=args.umap_ndim).fit_transform(embeddings)
df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
min_samples=args.min_sample,
core_dist_n_jobs=-1,
cluster_selection_epsilon=args.eps,
cluster_selection_method='leaf').fit_predict(umap_)
print(df.cluster.value_counts().describe())
plt.savefig('projection')
plt.show()
if input('\nType y to print cluster pngs.\n/!\ the cluster_pngs folder will be reset, backup if needed /!\ ') != 'y': if input('\nType y to print cluster pngs.\n/!\ the cluster_pngs folder will be reset, backup if needed /!\ ') != 'y':
exit() exit()
...@@ -112,11 +115,11 @@ os.system('rm -R cluster_pngs/*') ...@@ -112,11 +115,11 @@ os.system('rm -R cluster_pngs/*')
for c, grp in df.groupby('cluster'): for c, grp in df.groupby('cluster'):
if c == -1 or len(grp) > 10_000: if c == -1 or len(grp) > 10_000:
continue continue
os.system('mkdir -p cluster_pngs/'+str(c)) os.system(f'mkdir -p cluster_pngs/{c:.0f}')
loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8) loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8, collate_fn=u.collate_fn)
with torch.no_grad(): with torch.no_grad():
for x, idx in tqdm(loader, leave=False, desc=str(c)): for x, idx in tqdm(loader, leave=False, desc=str(int(c))):
plt.imshow(frontend(x).squeeze().numpy(), origin='lower', aspect='auto') plt.imshow(frontend(x).squeeze().numpy(), origin='lower', aspect='auto')
plt.subplots_adjust(top=1, bottom=0, left=0, right=1) plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
plt.savefig(f'cluster_pngs/{c}/{idx.squeeze().item()}') plt.savefig(f'cluster_pngs/{c:.0f}/{idx.squeeze().item()}')
plt.close() plt.close()
0% Loading Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment