Commit 8f733869 authored by Paul Best

fixes

parent 197e07b2
@@ -3,7 +3,7 @@ import models
import numpy as np, pandas as pd, torch
import umap
from tqdm import tqdm
-import argparse
+import argparse, os
torch.multiprocessing.set_sharing_strategy('file_system')
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the auto-encoder embeddings of vocalizations once it was trained with train_AE.py")
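For context, this script pairs a detections table with a trained auto-encoder checkpoint. A minimal invocation sketch, assuming the positional arguments correspond to the args.detections and args.modelname attributes used below (file names and argument order are illustrative, not taken from the repository):

python compute_embeddings.py detections.csv myAE.weights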
@@ -24,6 +24,7 @@ frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel
encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
model.load_state_dict(torch.load(args.modelname))
df = pd.read_csv(args.detections)
@@ -40,4 +41,6 @@ encodings = np.stack(encodings)
print('Computing UMAP projections...')
X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
-np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split(".")[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
+out_fn = f'encodings_{os.path.basename(args.detections).rsplit(".",1)[0]}.npy'
+print(f'Saving into {out_fn}')
+np.save(out_fn, {'encodings':encodings, 'idx':idxs, 'umap':X})
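Because the archive is a Python dict rather than a bare array, downstream code has to reload it with allow_pickle and unwrap it with .item(); a minimal sketch mirroring the load performed in the clustering script below:

data = np.load(out_fn, allow_pickle=True).item()
encodings, idxs, proj = data['encodings'], data['idx'], data['umap']

Note also that os.path.basename strips the directory part of args.detections, so the file now always lands in the working directory, and rsplit('.', 1) removes only the final extension.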
@@ -5,8 +5,7 @@ from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import torch, numpy as np, pandas as pd
-from filterbank import STFT, MelFilter, MedFilt, Log1p
-import hdbscan
+import hdbscan, umap
import argparse
import models
try:
@@ -22,6 +21,7 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFo
For insights on how to tune HDBSCAN parameters, read https://hdbscan.readthedocs.io/en/latest/parameter_selection.html""")
parser.add_argument('encodings', type=str, help='.npy file containing umap projections and their associated index in the detection.pkl table (built using compute_embeddings.py)')
parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
+parser.add_argument('-umap_ndim', type=int, help="number of dimensions for the UMAP compression", default=2)
parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
@@ -29,80 +29,83 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.')
parser.add_argument('-channel', type=int, default=0)
-parser.add_argument('-min_sample', type=int, default=5, help='Used for HDBSCAN clustering.')
-parser.add_argument('-eps', type=float, default=0.05, help='Used for HDBSCAN clustering.')
+parser.add_argument('-min_sample', type=int, default=3, help='Used for HDBSCAN clustering.')
+parser.add_argument('-eps', type=float, default=0.01, help='Used for HDBSCAN clustering.')
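The tightened defaults (min_samples 5 to 3, epsilon 0.05 to 0.01) make the 'leaf' selection split more aggressively into fine-grained clusters. A self-contained toy sketch of these parameters in action (synthetic data, illustrative only, not from the repository):

import numpy as np, hdbscan
rng = np.random.default_rng(0)
toy = np.vstack([rng.normal(0, .1, (50, 2)), rng.normal(3, .1, (50, 2))])  # two tight, well-separated blobs
labels = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3,
                         cluster_selection_epsilon=0.01,
                         cluster_selection_method='leaf').fit_predict(toy)
# expect two cluster ids here; -1 marks points rejected as noise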
args = parser.parse_args()
df = pd.read_csv(args.detections)
encodings = np.load(args.encodings, allow_pickle=True).item()
-idxs, umap = encodings['idx'], encodings['umap']
-df.loc[idxs, 'umap_x'] = umap[:,0]
-df.loc[idxs, 'umap_y'] = umap[:,1]
+idxs, umap_, embeddings = encodings['idx'], encodings['umap'], encodings['encodings']
-frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
-# Use HDBSCAN to cluster the embeddings (min_cluster_size and min_samples parameters need to be tuned)
-df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
-                                         min_samples=args.min_sample,
-                                         core_dist_n_jobs=-1,
-                                         cluster_selection_epsilon=args.eps,
-                                         cluster_selection_method='leaf').fit_predict(umap)
-df.cluster = df.cluster.astype(int)
+if args.umap_ndim == 2:
+    df.loc[idxs, 'umap_x'] = umap_[:,0]
+    df.loc[idxs, 'umap_y'] = umap_[:,1]
+    fs = 44100
+    frontend = torch.nn.Sequential(
+        STFT(2048, 256),
+        MelFilter(fs, 2048, 96, 500, 4000),
+        Log1p(4),
+        MedFilt()
+    )
+    # Use HDBSCAN to cluster the embeddings (min_cluster_size and min_samples parameters need to be tuned)
+    df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
+                                             min_samples=args.min_sample,
+                                             core_dist_n_jobs=-1,
+                                             cluster_selection_epsilon=args.eps,
+                                             cluster_selection_method='leaf').fit_predict(umap_)
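A side effect worth noting: the old df.cluster.astype(int) line is gone, and assigning fit_predict's integer labels to the rows selected by idxs leaves unlabeled rows as NaN, which upcasts the column to float64. That is why cluster ids are formatted with :.0f further down in this diff. A minimal pandas illustration (toy values):

import pandas as pd
df_toy = pd.DataFrame(index=range(4))
df_toy.loc[[0, 2], 'cluster'] = [1, -1]   # rows 1 and 3 stay NaN
print(df_toy.cluster.dtype)               # float64, hence f'{c:.0f}' for folder names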
+    figscat = plt.figure(figsize=(10, 5))
+    plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
+    plt.scatter(umap_[:,0], umap_[:,1], s=3, alpha=.8, c=df.loc[idxs].cluster, cmap='tab20')
+    plt.tight_layout()
+    axScat = figscat.axes[0]
+    plt.savefig('projection')
-figscat = plt.figure(figsize=(10, 5))
-plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
-plt.scatter(umap[:,0], umap[:,1], s=3, alpha=.8, c=df.loc[idxs].cluster, cmap='tab20')
-plt.tight_layout()
-axScat = figscat.axes[0]
+    figSpec = plt.figure()
+    plt.scatter(0, 0)
+    axSpec = figSpec.axes[0]
-figSpec = plt.figure()
-plt.scatter(0, 0)
-axSpec = figSpec.axes[0]
+    #print(df.cluster.value_counts())
-#print(df.cluster.value_counts())
-class temp():
-    def __init__(self):
-        self.row = ""
-    def onclick(self, event):
-        # find the closest point to the mouse
-        left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1]
-        rangex, rangey = right - left, top - bottom
-        closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin()
-        row = df.loc[closest]
-        # read waveform and compute spectrogram
-        info = sf.info(f'{args.audio_folder}/{row.filename}')
-        dur, fs = info.duration, info.samplerate
-        start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
-        sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
-        sig = sig[:, args.channel]
-        if fs != args.SR:
-            sig = signal.resample(sig, int(len(sig)/fs*args.SR))
-        spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
-        axSpec.clear()
-        axSpec.imshow(spec, origin='lower', aspect='auto')
-        # Display and metadata management
-        axSpec.set_title(f'{closest}, cluster {row.cluster} ({(df.cluster==row.cluster).sum()} points)')
-        axScat.scatter(row.umap_x, row.umap_y, c='r')
-        axScat.set_xlim(left, right)
-        axScat.set_ylim(bottom, top)
-        figSpec.canvas.draw()
-        figscat.canvas.draw()
-        # Play the audio
-        if soundAvailable:
-            sd.play(sig, fs)
+    class temp():
+        def __init__(self):
+            self.row = ""
+        def onclick(self, event):
+            # find the closest point to the mouse
+            left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1]
+            rangex, rangey = right - left, top - bottom
+            closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin()
+            row = df.loc[closest]
+            # read waveform and compute spectrogram
+            info = sf.info(f'{args.audio_folder}/{row.filename}')
+            dur, fs = info.duration, info.samplerate
+            start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
+            sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
+            sig = sig[:, args.channel]
+            if fs != args.SR:
+                sig = signal.resample(sig, int(len(sig)/fs*args.SR))
+            spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
+            axSpec.clear()
+            axSpec.imshow(spec, origin='lower', aspect='auto')
+            # Display and metadata management
+            axSpec.set_title(f'{closest}, cluster {row.cluster:.0f} ({(df.cluster==row.cluster).sum()} points)')
+            axScat.scatter(row.umap_x, row.umap_y, c='r')
+            axScat.set_xlim(left, right)
+            axScat.set_ylim(bottom, top)
+            figSpec.canvas.draw()
+            figscat.canvas.draw()
+            # Play the audio
+            if soundAvailable:
+                sd.play(sig, fs)
-mtemp = temp()
-cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
+    mtemp = temp()
+    cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
+    plt.show()
+else:
+    umap_ = umap.UMAP(n_jobs=-1, n_components=args.umap_ndim).fit_transform(embeddings)
+    df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
+                                             min_samples=args.min_sample,
+                                             core_dist_n_jobs=-1,
+                                             cluster_selection_epsilon=args.eps,
+                                             cluster_selection_method='leaf').fit_predict(umap_)
+    print(df.cluster.value_counts().describe())
-plt.savefig('projection')
-plt.show()
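In this branch the projection is recomputed directly from the stored auto-encoder embeddings at the requested dimensionality, and df.cluster.value_counts().describe() summarizes the result: count is the number of distinct cluster ids (including -1 for noise) and mean/min/max are statistics over cluster sizes. Reading it, schematically:

sizes = df.cluster.value_counts()  # one entry per cluster id, value = number of points
print(sizes.describe())            # count = number of clusters, mean/min/max = size statistics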
if input('\nType y to print cluster pngs.\n/!\ the cluster_pngs folder will be reset, backup if needed /!\ ') != 'y':
    exit()
@@ -112,11 +115,11 @@ os.system('rm -R cluster_pngs/*')
for c, grp in df.groupby('cluster'):
    if c == -1 or len(grp) > 10_000:
        continue
-    os.system('mkdir -p cluster_pngs/'+str(c))
-    loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8)
+    os.system(f'mkdir -p cluster_pngs/{c:.0f}')
+    loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8, collate_fn=u.collate_fn)
    with torch.no_grad():
-        for x, idx in tqdm(loader, leave=False, desc=str(c)):
+        for x, idx in tqdm(loader, leave=False, desc=str(int(c))):
            plt.imshow(frontend(x).squeeze().numpy(), origin='lower', aspect='auto')
            plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
-            plt.savefig(f'cluster_pngs/{c}/{idx.squeeze().item()}')
+            plt.savefig(f'cluster_pngs/{c:.0f}/{idx.squeeze().item()}')
            plt.close()
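Taken together, the second script consumes the artifacts of the first. A hypothetical end-to-end run (the clustering script's file name is not visible in this diff, sort_cluster.py is an assumption, and argument order follows the sketch above):

python compute_embeddings.py detections.csv myAE.weights
python sort_cluster.py encodings_detections.npy detections.csv -min_cluster_size 10 -min_sample 3 -eps 0.01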