Skip to content
Snippets Groups Projects
Commit 36a96885 authored by Paul Best's avatar Paul Best
Browse files
parents 726ea6dc 62e1ebe6
No related branches found
No related tags found
No related merge requests found
Scripts to train an auto-encoder and cluster animal vocalisations by frequency-contour similarity
Scripts to train an auto-encoder and cluster animal vocalisations by frequency-contour similarity.
Vocalisation detection needs to be done prior to this process (stored in a .csv file)
Scripts need to be called in the following order :
train_AE.py
compute_embeddings.py
sort_cluster.py
use `python myscript.py --help` to get more information on each scripts' usage
### Scripts need to be called in the following order:
If you want to train your own auto-encoder (optional since the generic one might suffice), use `python train_AE.py detections.csv`
Use trained auto-encoder weights to project your vocalisations with `python compute_embeddings.py generic_embedder.weights detections.csv`
Visualise vocalisation embeddings and resulting clusters with `python sort_cluster.py embeddings.npy detections.csv`
Use `python myscript.py --help` to get more information on each scripts' usage and options
required packages can be install using `pip install -r requirements.txt`
......@@ -4,9 +4,10 @@ import numpy as np, pandas as pd, torch
import umap
from tqdm import tqdm
import argparse
torch.multiprocessing.set_sharing_strategy('file_system')
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the auto-encoder embeddings of vocalizations once it was trained with train_AE.py")
parser.add_argument('modelname', type=str, help='Filename of the AE weights (.stdc)')
parser.add_argument('modelname', type=str, help='Filename of the AE weights (.stdc or .weights)')
parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
......@@ -26,7 +27,7 @@ df = pd.read_csv(args.detections)
print('Computing AE projections...')
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur, channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8)
with torch.no_grad():
with torch.inference_mode():
encodings, idxs = [], []
for x, idx in tqdm(loader):
encoding = model[:2](x.to(device))
......@@ -37,4 +38,4 @@ encodings = np.stack(encodings)
print('Computing UMAP projections...')
X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
np.save('encodings_'+args.modelname[:-4]+'npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split('.')[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
......@@ -33,7 +33,7 @@ parser.add_argument('-min_sample', type=int, default=5, help='Used for HDBSCAN c
parser.add_argument('-eps', type=float, default=0.05, help='Used for HDBSCAN clustering.')
args = parser.parse_args()
df = pd.read_csv(args.detections, index_col=0)
df = pd.read_csv(args.detections)
encodings = np.load(args.encodings, allow_pickle=True).item()
idxs, umap = encodings['idx'], encodings['umap']
df.loc[idxs, 'umap_x'] = umap[:,0]
......
......@@ -40,7 +40,7 @@ loss_fun = torch.nn.MSELoss()
df = pd.read_csv(args.detections)
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.stdc'
modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.weights'
step, writer = 0, SummaryWriter('runs/'+modelname)
print(f'Go for model {modelname} with {len(df)} vocalizations')
for epoch in range(100_000//len(loader)):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment