diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index 93f344d5b9c1bcacd522451e75083b5c24e0c568..2b18520d395e79ce5e4ab74b1b6c6b1e024f309a 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -3,7 +3,7 @@ import models
 import numpy as np, pandas as pd, torch
 import umap
 from tqdm import tqdm
-import argparse
+import argparse, os
 torch.multiprocessing.set_sharing_strategy('file_system')
 
 parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the auto-encoder embeddings of vocalizations once it was trained with train_AE.py")
@@ -24,6 +24,7 @@ frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
+model.load_state_dict(torch.load(args.modelname))
 
 df = pd.read_csv(args.detections)
 
@@ -40,4 +41,6 @@ encodings = np.stack(encodings)
 print('Computing UMAP projections...')
 X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
 
-np.save(f'encodings_{args.detections[:-4]}_{args.modelname.split(".")[0]}.npy', {'encodings':encodings, 'idx':idxs, 'umap':X})
+out_fn = f'encodings_{os.path.basename(args.detections).rsplit(".",1)[0]}.npy'
+print(f'Saving into {out_fn}')
+np.save(out_fn, {'encodings':encodings, 'idx':idxs, 'umap':X})
diff --git a/new_specie/sort_cluster.py b/new_specie/sort_cluster.py
index 56e5dc7705bf92248f9a53b9296fdd392b9646e6..aea8ada1b4876ba268ff6d16bfd83d26b434a5e6 100755
--- a/new_specie/sort_cluster.py
+++ b/new_specie/sort_cluster.py
@@ -5,8 +5,7 @@ from tqdm import tqdm
 import matplotlib.pyplot as plt
 import os
 import torch, numpy as np, pandas as pd
-from filterbank import STFT, MelFilter, MedFilt, Log1p
-import hdbscan
+import hdbscan, umap
 import argparse
 import models
 try:
@@ -22,6 +21,7 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFo
 For insights on how to tune HDBSCAN parameters, read https://hdbscan.readthedocs.io/en/latest/parameter_selection.html""")
 parser.add_argument('encodings', type=str, help='.npy file containing umap projections and their associated index in the detection.pkl table (built using compute_embeddings.py)')
 parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
+parser.add_argument('-umap_ndim', type=int, help="number of dimension for the UMAP compression", default=2)
 parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
 parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
 parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
@@ -29,80 +29,83 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
 parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.')
 parser.add_argument('-channel', type=int, default=0)
-parser.add_argument('-min_sample', type=int, default=5, help='Used for HDBSCAN clustering.')
-parser.add_argument('-eps', type=float, default=0.05, help='Used for HDBSCAN clustering.')
+parser.add_argument('-min_sample', type=int, default=3, help='Used for HDBSCAN clustering.')
+parser.add_argument('-eps', type=float, default=0.01, help='Used for HDBSCAN clustering.')
 args = parser.parse_args()
 
 df = pd.read_csv(args.detections)
 encodings = np.load(args.encodings, allow_pickle=True).item()
-idxs, umap = encodings['idx'], encodings['umap']
-df.loc[idxs, 'umap_x'] = umap[:,0]
-df.loc[idxs, 'umap_y'] = umap[:,1]
+idxs, umap_, embeddings = encodings['idx'], encodings['umap'], encodings['encodings']
+frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
 
-# Use HDBSCAN to cluster the embedings (min_cluster_size and min_samples parameters need to be tuned)
-df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
-                                         min_samples=args.min_sample,
-                                         core_dist_n_jobs=-1,
-                                         cluster_selection_epsilon=args.eps,
-                                         cluster_selection_method='leaf').fit_predict(umap)
-df.cluster = df.cluster.astype(int)
+if args.umap_ndim == 2:
+    df.loc[idxs, 'umap_x'] = umap_[:,0]
+    df.loc[idxs, 'umap_y'] = umap_[:,1]
 
-fs = 44100
-frontend = torch.nn.Sequential(
-    STFT(2048, 256),
-    MelFilter(fs, 2048, 96, 500, 4000),
-    Log1p(4),
-    MedFilt()
-)
+    # Use HDBSCAN to cluster the embedings (min_cluster_size and min_samples parameters need to be tuned)
+    df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
+                                             min_samples=args.min_sample,
+                                             core_dist_n_jobs=-1,
+                                             cluster_selection_epsilon=args.eps,
+                                             cluster_selection_method='leaf').fit_predict(umap_)
+    figscat = plt.figure(figsize=(10, 5))
+    plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
+    plt.scatter(umap_[:,0], umap_[:,1], s=3, alpha=.8, c=df.loc[idxs].cluster, cmap='tab20')
+    plt.tight_layout()
+    axScat = figscat.axes[0]
+    plt.savefig('projection')
 
-figscat = plt.figure(figsize=(10, 5))
-plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
-plt.scatter(umap[:,0], umap[:,1], s=3, alpha=.8, c=df.loc[idxs].cluster, cmap='tab20')
-plt.tight_layout()
-axScat = figscat.axes[0]
-figSpec = plt.figure()
-plt.scatter(0, 0)
-axSpec = figSpec.axes[0]
+    figSpec = plt.figure()
+    plt.scatter(0, 0)
+    axSpec = figSpec.axes[0]
 
-#print(df.cluster.value_counts())
+    #print(df.cluster.value_counts())
 
 
-class temp():
-    def __init__(self):
-        self.row = ""
-    def onclick(self, event):
-        # find the closest point to the mouse
-        left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1]
-        rangex, rangey = right - left, top - bottom
-        closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin()
-        row = df.loc[closest]
-        # read waveform and compute spectrogram
-        info = sf.info(f'{args.audio_folder}/{row.filename}')
-        dur, fs = info.duration, info.samplerate
-        start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
-        sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
-        sig = sig[:, args.channel]
-        if fs != args.SR:
-            sig = signal.resample(sig, int(len(sig)/fs*args.SR))
-        spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
-        axSpec.clear()
-        axSpec.imshow(spec, origin='lower', aspect='auto')
-        # Display and metadata management
-        axSpec.set_title(f'{closest}, cluster {row.cluster} ({(df.cluster==row.cluster).sum()} points)')
-        axScat.scatter(row.umap_x, row.umap_y, c='r')
-        axScat.set_xlim(left, right)
-        axScat.set_ylim(bottom, top)
-        figSpec.canvas.draw()
-        figscat.canvas.draw()
-        # Play the audio
-        if soundAvailable:
-            sd.play(sig, fs)
+    class temp():
+        def __init__(self):
+            self.row = ""
+        def onclick(self, event):
+            # find the closest point to the mouse
+            left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1]
+            rangex, rangey = right - left, top - bottom
+            closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin()
+            row = df.loc[closest]
+            # read waveform and compute spectrogram
+            info = sf.info(f'{args.audio_folder}/{row.filename}')
+            dur, fs = info.duration, info.samplerate
+            start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
+            sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
+            sig = sig[:, args.channel]
+            if fs != args.SR:
+                sig = signal.resample(sig, int(len(sig)/fs*args.SR))
+            spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
+            axSpec.clear()
+            axSpec.imshow(spec, origin='lower', aspect='auto')
+            # Display and metadata management
+            axSpec.set_title(f'{closest}, cluster {row.cluster:.0f} ({(df.cluster==row.cluster).sum()} points)')
+            axScat.scatter(row.umap_x, row.umap_y, c='r')
+            axScat.set_xlim(left, right)
+            axScat.set_ylim(bottom, top)
+            figSpec.canvas.draw()
+            figscat.canvas.draw()
+            # Play the audio
+            if soundAvailable:
+                sd.play(sig, fs)
 
-mtemp = temp()
-cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
+    mtemp = temp()
+    cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
+
+    plt.show()
+else :
+    umap_ = umap.UMAP(n_jobs=-1, n_components=args.umap_ndim).fit_transform(embeddings)
+    df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
+                                             min_samples=args.min_sample,
+                                             core_dist_n_jobs=-1,
+                                             cluster_selection_epsilon=args.eps,
+                                             cluster_selection_method='leaf').fit_predict(umap_)
+    print(df.cluster.value_counts().describe())
 
-plt.savefig('projection')
-plt.show()
 
 if input('\nType y to print cluster pngs.\n/!\ the cluster_pngs folder will be reset, backup if needed /!\ ') != 'y':
     exit()
@@ -112,11 +115,11 @@ os.system('rm -R cluster_pngs/*')
 for c, grp in df.groupby('cluster'):
     if c == -1 or len(grp) > 10_000:
         continue
-    os.system('mkdir -p cluster_pngs/'+str(c))
-    loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8)
+    os.system(f'mkdir -p cluster_pngs/{c:.0f}')
+    loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8, collate_fn=u.collate_fn)
    with torch.no_grad():
-        for x, idx in tqdm(loader, leave=False, desc=str(c)):
+        for x, idx in tqdm(loader, leave=False, desc=str(int(c))):
             plt.imshow(frontend(x).squeeze().numpy(), origin='lower', aspect='auto')
             plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
-            plt.savefig(f'cluster_pngs/{c}/{idx.squeeze().item()}')
+            plt.savefig(f'cluster_pngs/{c:.0f}/{idx.squeeze().item()}')
             plt.close()