diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index 6e17b83cf15958fa546002632e2c94c421abeafd..b25717cd09b095360ea5bd010ea1645ec949ebae 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -5,8 +5,7 @@ import umap
 from tqdm import tqdm
 import argparse
 
-parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, \
-    description="Compute the auto-encoder embeddings of vocalizations once it was trained with train_AE.py")
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Compute the auto-encoder embeddings of vocalizations once the model has been trained with train_AE.py")
 parser.add_argument('modelname', type=str, help='Filename of the AE weights (.stdc)')
 parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
 parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
@@ -19,10 +18,9 @@ args = parser.parse_args()
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
-encoder = models.sparrow_encoder(args.bottleneck)
-decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4))
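+# the conv stack ends on a (nMel//32, 4) feature map, so giving the encoder
+# bottleneck // (nMel//32 * 4) channels makes the flattened embedding exactly bottleneck-dimensional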
+encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
+decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
 model.load_state_dict(torch.load(args.modelname, map_location=device))
 
 df = pd.read_csv(args.detections)
 
diff --git a/new_specie/models.py b/new_specie/models.py
index 715e73db1b7ee43a11861b571c5a35f22d2dbe40..48d7cf84ec44fc40247731fe0c8b15cbbb54e559 100755
--- a/new_specie/models.py
+++ b/new_specie/models.py
@@ -13,16 +13,17 @@ for nm, mod in vgg16.named_modules():
 frontend = lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
   STFT(nfft, int((sampleDur*sr - nfft)/128)),
   MelFilter(sr, nfft, n_mel, sr//nfft, sr//2),
-  Log1p(7, trainable=False)
+  Log1p(7, trainable=False),
+  nn.InstanceNorm2d(1),
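+  # crop to a fixed (n_mel, 128) time-frequency size so the conv stack always sees the same input shape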
+  u.Croper2D(n_mel, 128)
 )
 
-sparrow_encoder = lambda nfeat : nn.Sequential(
+sparrow_encoder = lambda nfeat, shape : nn.Sequential(
   nn.Conv2d(1, 32, 3, stride=2, bias=False, padding=(1)),
   nn.BatchNorm2d(32),
   nn.LeakyReLU(0.01),
   nn.Conv2d(32, 64, 3, stride=2, bias=False, padding=1),
   nn.BatchNorm2d(64),
-  nn.MaxPool2d((1, 2)),
   nn.ReLU(True),
   nn.Conv2d(64, 128, 3, stride=2, bias=False, padding=1),
   nn.BatchNorm2d(128),
@@ -31,17 +32,15 @@ sparrow_encoder = lambda nfeat : nn.Sequential(
   nn.BatchNorm2d(256),
   nn.ReLU(True),
   nn.Conv2d(256, nfeat, (3, 5), stride=2, padding=(1, 2)),
-  nn.AdaptiveMaxPool2d((1,1)),
-  u.Reshape(nfeat)
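+  # flatten the final (nfeat, shape[0], shape[1]) feature volume into a single embedding vector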
+  u.Reshape(nfeat * shape[0] * shape[1])
 )
 
 sparrow_decoder = lambda nfeat, shape : nn.Sequential(
-  nn.Linear(nfeat, nfeat*shape[0]*shape[1]),
-  u.Reshape(nfeat, *shape),
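+  # undo the encoder's flatten: reshape the embedding back into a (channels, *shape) feature volume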
+  u.Reshape(nfeat//(shape[0]*shape[1]), *shape),
   nn.ReLU(True),
 
   nn.Upsample(scale_factor=2),
-  nn.Conv2d(nfeat, 256, (3, 3), bias=False, padding=1),
+  nn.Conv2d(nfeat//(shape[0]*shape[1]), 256, (3, 3), bias=False, padding=1),
   nn.BatchNorm2d(256),
   nn.ReLU(True),
   nn.Conv2d(256, 256, (3, 3), bias=False, padding=1),
diff --git a/new_specie/sort_cluster.py b/new_specie/sort_cluster.py
index ba26d2f25cf632cf94d72028b52926107bcf3319..6ab1a705ffff57582999dfc6986d2c7f3105cd6d 100755
--- a/new_specie/sort_cluster.py
+++ b/new_specie/sort_cluster.py
@@ -1,3 +1,5 @@
+import soundfile as sf
+from scipy import signal
 import utils as u
 from tqdm import tqdm
 import matplotlib.pyplot as plt
@@ -5,6 +7,12 @@ import os
 import torch, numpy as np, pandas as pd
 import hdbscan
 import argparse
+import models
+try:
+    import sounddevice as sd
+    soundAvailable = True
+except Exception:  # sounddevice can raise ImportError, or OSError if PortAudio is missing
+    soundAvailable = False
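+# sounddevice is optional: without it, clicking still shows spectrograms but plays no audio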
 
 parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, \
     description="""Interface to visualize projected vocalizations (UMAP reduced AE embeddings), and tune HDBSCAN parameters.\n
@@ -13,18 +21,23 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFo
     For insights on how to tune HDBSCAN parameters, read https://hdbscan.readthedocs.io/en/latest/parameter_selection.html""")
 parser.add_argument('encodings', type=str, help='.npy file containing umap projections and their associated index in the detections .csv table (built using compute_embeddings.py)')
 parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
-parser.add_argument('audio_folder', type=str, help='Path to the folder with complete audio files')
+parser.add_argument("-audio_folder", type=str, default='', help="Folder from which to load sound files")
 parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
+parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
+parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
-parser.add_argument('-min_cluster_size', type=int, default=50, help='Used for HDBSCAN clustering.')
+parser.add_argument('-min_cluster_size', type=int, default=10, help='HDBSCAN min_cluster_size parameter (smallest group of points accepted as a cluster)')
 parser.add_argument('-min_sample', type=int, default=5, help='Used for HDBSCAN clustering.')
-parser.add_argument('-eps', type=float, default=0.0, help='Used for HDBSCAN clustering.')
+parser.add_argument('-eps', type=float, default=0.05, help='HDBSCAN cluster_selection_epsilon parameter')
 args = parser.parse_args()
 
-gpu = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 df = pd.read_csv(args.detections)
 encodings = np.load(args.encodings, allow_pickle=True).item()
 idxs, umap = encodings['idx'], encodings['umap']
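+# keep each detection's 2D coordinates in the dataframe so the click handler below can find the nearest point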
+df.loc[idxs, 'umap_x'] = umap[:,0]
+df.loc[idxs, 'umap_y'] = umap[:,1]
+
 # Use HDBSCAN to cluster the embeddings (min_cluster_size and min_samples parameters need to be tuned)
-df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
+df.loc[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
                                 min_samples=args.min_sample,
@@ -33,10 +46,53 @@ df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
                                 cluster_selection_method='leaf').fit_predict(umap)
 df.cluster = df.cluster.astype(int)
 
+frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
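+# display-only spectrogram frontend; SR, NFFT and nMel must match the values used when training the AE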
+
 figscat = plt.figure(figsize=(10, 5))
 plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
-plt.scatter(umap[:,0], umap[:,1], s=3, alpha=.2, c=df.loc[idxs].cluster, cmap='tab20')
+plt.scatter(umap[:,0], umap[:,1], s=3, alpha=.8, c=df.loc[idxs].cluster, cmap='tab20')
 plt.tight_layout()
+axScat = figscat.axes[0]
+figSpec = plt.figure()
+axSpec = figSpec.add_subplot(111)
+
+
+class ClickHandler():
+    """On click: show the nearest detection's spectrogram and play its audio."""
+    def onclick(self, event):
+        if event.inaxes != axScat:
+            return
+        # find the point closest to the click (distances normalised by the axis ranges)
+        (left, right), (bottom, top) = axScat.get_xlim(), axScat.get_ylim()
+        rangex, rangey = right - left, top - bottom
+        closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin()
+        row = df.loc[closest]
+        # read waveform and compute spectrogram
+        info = sf.info(f'{args.audio_folder}/{row.filename}')
+        dur, fs = info.duration, info.samplerate
+        start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
+        sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
+        sig = sig[:,0]
+        if fs != args.SR:
+            sig = signal.resample(sig, int(len(sig)/fs*args.SR))
+        spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
+        axSpec.clear()
+        axSpec.imshow(spec, origin='lower', aspect='auto')
+        # Display and metadata management
+        axSpec.set_title(f'{closest}, cluster {row.cluster} ({(df.cluster==row.cluster).sum()} points)')
+        axScat.scatter(row.umap_x, row.umap_y, c='r')
+        axScat.set_xlim(left, right)
+        axScat.set_ylim(bottom, top)
+        figSpec.canvas.draw()
+        figscat.canvas.draw()
+        # Play the audio
+        if soundAvailable:
+            sd.play(sig*10, args.SR)  # sig is at args.SR after the resampling above
+
+handler = ClickHandler()
+cid = figscat.canvas.mpl_connect('button_press_event', handler.onclick)
+
-plt.savefig('projection')
+figscat.savefig('projection')
 plt.show()
 
diff --git a/new_specie/train_AE.py b/new_specie/train_AE.py
index 9138d3fbc85a28cc2302bc6bb319020369b7f70c..15f63ebe99eb00e0d2a958b21476a1c60e6a673d 100755
--- a/new_specie/train_AE.py
+++ b/new_specie/train_AE.py
@@ -6,9 +6,10 @@ from tqdm import tqdm
 from torch.utils.tensorboard import SummaryWriter
 import argparse
 
-parser = argparse.ArgumentParser(description="""This script trains an auto-encoder to compress and depcompress vocalisation spectrograms.\n
-            Reconstruction quality can be monitored via tensorboard ($tensorboard --logdir=runs/ --bind_all)""")
-parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
+parser = argparse.ArgumentParser(description="""This script trains an auto-encoder to compress and decompress vocalisation spectrograms.
+                                 Reconstruction quality can be monitored via tensorboard ($tensorboard --logdir=runs/ --bind_all)""",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. A column \'filename\' (path of the soundfile) and a column \'pos\{ (center of the detection in seconds) are needed")
 parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
 parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
 parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
@@ -17,47 +18,43 @@ parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signa
 parser.add_argument("-bottleneck", type=int, default=16, help='size of the auto-encoder\'s bottleneck')
 args = parser.parse_args()
 
-df = pd.read_csv(args.detections)
-print(f'Training using {len(df)} vocalizations')
-
-nepoch = 100
-batch_size = 64 if torch.cuda.is_available() else 16
-nfeat = 16
-modelname = args.detections[:-4]+'_AE_weights.stdc'
+# init AE architecture
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-lr = 0.003
-wdL2 = 0.0
-writer = SummaryWriter('runs/'+modelname)
-vgg16 = models.vgg16.eval().to(device)
-
+assert args.nMel % 32 == 0 and args.nMel > 0, "nMel should be a positive multiple of 32"
+assert args.bottleneck % (args.nMel//32 * 4) == 0, "bottleneck must be a multiple of the last feature volume's size (nMel//32 * 4)"
 frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
-encoder = models.sparrow_encoder(args.bottleneck)
-decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4))
+encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
+decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
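+# encoder channels * (nMel//32 * 4) spatial positions == bottleneck, as enforced by the assert above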
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
 
-
+# training / optimisation setup
+lr, wdL2, batch_size = 0.003, 0.0, 64 if torch.cuda.is_available() else 16
 optimizer = torch.optim.AdamW(model.parameters(), weight_decay=wdL2, lr=lr, betas=(0.8, 0.999))
 scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch : .99**epoch)
-loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
+vgg16 = models.vgg16.eval().to(device)
 loss_fun = torch.nn.MSELoss()
 
-print('Go for model '+modelname)
-step = 0
-for epoch in range(nepoch):
+# data loader
+df = pd.read_csv(args.detections)
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=batch_size, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
+
+modelname = f'{args.detections[:-4]}_AE_{args.bottleneck}_mel{args.nMel}.stdc'
+step, writer = 0, SummaryWriter('runs/'+modelname)
+print(f'Go for model {modelname} with {len(df)} vocalizations')
+for epoch in range(max(1, 100_000//len(loader))):  # aim for ~100k optimisation steps regardless of dataset size
     for x, name in tqdm(loader, desc=str(epoch), leave=False):
         optimizer.zero_grad()
         label = frontend(x.to(device))
         x = encoder(label)
         pred = decoder(x)
-        predd = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:]))
-        labell = vgg16(label.expand(label.shape[0], 3, *label.shape[2:]))
+        vgg_pred = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:]))
+        vgg_label = vgg16(label.expand(label.shape[0], 3, *label.shape[2:]))
 
-        score = loss_fun(predd, labell)
+        score = loss_fun(vgg_pred, vgg_label)
         score.backward()
         optimizer.step()
         writer.add_scalar('loss', score.item(), step)
 
-
         if step%50==0 :
             images = [(e-e.min())/(e.max()-e.min()) for e in label[:8]]
             grid = make_grid(images)
@@ -68,7 +65,6 @@ for epoch in range(nepoch):
             writer.add_image('reconstruct', grid, step)
 
         step += 1
-    if epoch % 10 == 0:
-        scheduler.step()
+        if step % 500 == 0:
+            scheduler.step()
     torch.save(model.state_dict(), modelname)
-
diff --git a/new_specie/utils.py b/new_specie/utils.py
index 8439833f8e2c7e2de8cb079b281e30bc1298a579..5c6644f6546275418b03cf9741321b4de758a0d6 100755
--- a/new_specie/utils.py
+++ b/new_specie/utils.py
@@ -1,14 +1,15 @@
 import soundfile as sf
-import torch
+from torch import nn, Tensor
+from torch.utils.data import Dataset as TorchDataset, dataloader
 import numpy as np
 from scipy.signal import resample
 
 
 def collate_fn(batch):
     batch = list(filter(lambda x: x is not None, batch))
-    return torch.utils.data.dataloader.default_collate(batch)
+    return dataloader.default_collate(batch)
 
-class Dataset(torch.utils.data.Dataset):
+class Dataset(TorchDataset):
     def __init__(self, df, audiopath, sr, sampleDur):
-        super(Dataset, self)
+        super().__init__()
         self.audiopath, self.df, self.sr, self.sampleDur = audiopath, df, sr, sampleDur
@@ -25,30 +26,34 @@ class Dataset(torch.utils.data.Dataset):
             sig, fs = sf.read(self.audiopath+'/'+row.filename, start=start, stop=start + int(self.sampleDur*fs), always_2d=True)
             sig = sig[:,0]
         except:
-            print(f'failed to load sound from row {row.name} with filename {row.filename}')
+            print(f'Failed to load sound from row {row.name} with filename {row.filename}')
             return None
         if len(sig) < self.sampleDur * fs:
-            sig = np.pad(sig, int(self.sampleDur * fs - len(sig))//2+1, mode='reflect')[:int(self.sampleDur * fs)]
+            sig = np.concatenate([sig, np.zeros(int(self.sampleDur * fs) - len(sig))])
         if fs != self.sr:
             sig = resample(sig, int(len(sig)/fs*self.sr))
-        return torch.Tensor(norm(sig)).float(), row.name
+        return Tensor(norm(sig)).float(), row.name
 
 
 def norm(arr):
     return (arr - np.mean(arr) ) / np.std(arr)
 
-
-class Flatten(torch.nn.Module):
+class Flatten(nn.Module):
     def __init__(self):
         super(Flatten, self).__init__()
     def forward(self, x):
         return x.view(x.shape[0], -1)
 
-
-class Reshape(torch.nn.Module):
+class Reshape(nn.Module):
     def __init__(self, *shape):
         super(Reshape, self).__init__()
         self.shape = shape
-
     def forward(self, x):
         return x.view(x.shape[0], *self.shape)
+
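+# crops the last two dimensions of a (batch, channel, H, W) tensor down to (shape[0], shape[1])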
+class Croper2D(nn.Module):
+    def __init__(self, *shape):
+        super(Croper2D, self).__init__()
+        self.shape = shape
+    def forward(self, x):
+        return x[:,:,:self.shape[0],:self.shape[1]]