diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index 87e7e5f24f117419e220e481a3e9b4a907234e3f..2970791dc7f11e3672f60d358c1ebbe9bea0eb8f 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -30,7 +30,8 @@ model.load_state_dict(torch.load(args.modelname))
 df = pd.read_csv(args.detections)
 
 print('Computing AE projections...')
-loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0), channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
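+# with -medfilt, 2 extra seconds are loaded around each detection solely to improve the frequency-wise median estimation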
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8, collate_fn=u.collate_fn)
 with torch.inference_mode():
     encodings, idxs = [], []
     for x, idx in tqdm(loader):
diff --git a/new_specie/print_annot.py b/new_specie/print_annot.py
index ce5591c7dcccb86d5018e01937ec46198fb45fbc..f178499b2f9d22e25bcd63ecd50204df10be0313 100755
--- a/new_specie/print_annot.py
+++ b/new_specie/print_annot.py
@@ -1,55 +1,34 @@
-import soundfile as sf
-import os
+import os, argparse
 from tqdm import tqdm
 import matplotlib.pyplot as plt
-import soundfile as sf
 import pandas as pd, numpy as np
 import models, utils as u
 import torch
-from filterbank import MelFilter, STFT, Log1p, MedFilt
 
-fs = 44_100
-sampleDur = 2.5
-nMel = 96
-hop = 260
-width = int((sampleDur * fs - 2048)/hop) + 1
-pad_width = int((1 * fs)/hop) + 1
 
-frontend = torch.nn.Sequential(
-  STFT(2048, hop),
-  MelFilter(fs, 2048, 96, 500, 4000),
-  Log1p(4),
-  MedFilt()
-)
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Plot vocalisation spectrograms into annot_pngs/")
+parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
+parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
+parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
+parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
+parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
+parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
+parser.add_argument('-medfilt', action='store_true', help="Apply a frequency-wise median filter to the spectrograms (extracts are loaded 2 s longer solely to improve the median estimation)")
+args = parser.parse_args()
+os.makedirs('annot_pngs', exist_ok=True)  # create the output folder before saving figures
 
-#df = pd.read_csv('S4A09100_20220527$045924.Table.1.selections.txt', delimiter='\t')
-#df['filename'] = 'S4A09100_20220527$045924.wav'
-#df['pos'] = df['Begin Time (s)'] + 1.25
-#df['dur'] = df['End Time (s)'] - df['Begin Time (s)']
-#df.drop(df[((df.dur < 2.5)|(df.dur > 8))].index, inplace=True)
-df = pd.read_csv('all_annot.csv')
-df = df['Cao'] = 'C'
-df['pos'] = df['Begin Time (s)'] + 1.25
 
-loader = torch.utils.data.DataLoader(u.Dataset(df, './', fs, 2.5 + 2), batch_size=1, num_workers=8, collate_fn=u.collate_fn)
+frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+df = pd.read_csv(args.detections)
+
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=1, num_workers=8, collate_fn=u.collate_fn, shuffle=True)
 
 for x, idx in tqdm(loader):
-    x = frontend(x).squeeze().detach()[:,pad_width:-pad_width]
+    x = frontend(x).squeeze().detach()
     plt.figure()
-    plt.imshow(x, origin='lower', aspect='auto')
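+    # clip the colour scale to the .25-.98 quantile range for better contrast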
+    plt.imshow(x, origin='lower', aspect='auto', cmap='Greys', vmin=torch.quantile(x, .25), vmax=torch.quantile(x, .98))
     plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
-    plt.savefig(f'gibbon_calls/cao_vit_DL/{idx.item()}')
+    plt.savefig(f'annot_pngs/{idx.item()}')
     plt.close()
 
-
-exit()
-
-sig, fs = sf.read('S4A09100_20220527$045924.wav')
-
-for i, r in tqdm(df.iterrows(), total=len(df)):
-#    x = frontend(torch.Tensor(sig[int(fs*r['Begin Time (s)']):int(fs*r['End Time (s)']), 1]).unsqueeze(0)).squeeze().detach()
-    x = frontend(torch.Tensor(sig[int(fs*r['Begin Time (s)']):int(fs*(r['Begin Time (s)']+2.5)), 1]).unsqueeze(0)).squeeze().detach()
-    plt.imshow(x, origin='lower', aspect='auto', extent=[0, r.dur, 500, 5000])
-    plt.tight_layout()
-    plt.savefig(f'gibbon_calls/cao_vit_cut/{i}')
-    plt.close()
diff --git a/new_specie/sort_cluster.py b/new_specie/sort_cluster.py
index aea8ada1b4876ba268ff6d16bfd83d26b434a5e6..87167693962eaa68cf84d76d3d773469ce39ac46 100755
--- a/new_specie/sort_cluster.py
+++ b/new_specie/sort_cluster.py
@@ -29,6 +29,7 @@ parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spec
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
 parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.')
 parser.add_argument('-channel', type=int, default=0)
+parser.add_argument('-medfilt', action='store_true', help="Apply a frequency-wise median filter to the spectrograms (extracts are loaded 2 s longer solely to improve the median estimation)")
 parser.add_argument('-min_sample', type=int, default=3, help='Used for HDBSCAN clustering.')
 parser.add_argument('-eps', type=float, default=0.01, help='Used for HDBSCAN clustering.')
 args = parser.parse_args()
@@ -36,7 +37,10 @@ args = parser.parse_args()
 df = pd.read_csv(args.detections)
 encodings = np.load(args.encodings, allow_pickle=True).item()
 idxs, umap_, embeddings = encodings['idx'], encodings['umap'], encodings['encodings']
-frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+
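+# the frontend above keeps the original sampleDur; the extracts are loaded 2 s longer only to improve the median estimation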
+args.sampleDur += (2 if args.medfilt else 0)
 
 if args.umap_ndim == 2:
     df.loc[idxs, 'umap_x'] = umap_[:,0]
@@ -76,7 +80,8 @@ if args.umap_ndim == 2:
             dur, fs = info.duration, info.samplerate
             start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
             sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
-            sig = sig[:, args.channel]
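+            # prefer the detection's own Channel column (1-indexed) when present, else the -channel argument (0-indexed)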
+            sig = sig[:, row.Channel - 1 if 'Channel' in row else args.channel]
             if fs != args.SR:
                 sig = signal.resample(sig, int(len(sig)/fs*args.SR))
             spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
@@ -107,11 +112,18 @@ else :
     print(df.cluster.value_counts().describe())
 
 
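+# once every detection has a cluster, cast the labels from float (needed for NaN) to int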
+if df.cluster.isna().sum() == 0:
+    df.cluster = df.cluster.astype(int)
+
 if input('\nType y to print cluster pngs.\n/!\ the cluster_pngs folder will be reset, backup if needed /!\ ') != 'y':
     exit()
 
 os.system('rm -R cluster_pngs/*')
 
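+# store the detections with their cluster labels alongside the pngs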
+df.to_csv(f'cluster_pngs/{args.detections}', index=False)
+
 for c, grp in df.groupby('cluster'):
     if c == -1 or len(grp) > 10_000:
         continue