diff --git a/.gitignore b/.gitignore
index d4acccc40c48fb2183901bcb00260a82f5810f71..76ccee7b16358c65b6c72e2d84db0c85d666f015 100755
--- a/.gitignore
+++ b/.gitignore
@@ -6,7 +6,8 @@
 *__pycache__
 *log
 humpback2/annots
-
+gibbon
+new_specie/*/
 otter/pone.0112562.s003.xlsx
 zebra_finch/Library_notes.pdf
 annot_distrib.pdf
diff --git a/plot_annot_distrib.py b/plot_annot_distrib.py
index fc0433b3742c14ed2b6295b980b60dcea319a94e..c79e709de411ccc9a322552db712c0bba07f0f25 100755
--- a/plot_annot_distrib.py
+++ b/plot_annot_distrib.py
@@ -11,6 +11,7 @@ info = {
     'zebra_finch': ['zebra finch', 'elie2018zebra', 'bird'],
     'otter': ['otter', '', ''],
     'humpback': ['humpback whale', 'malige2021use', 'cetacean'],
+    'humpback2': ['humpback whale', 'malige2021use', 'cetacean'],
     'dolphin': ['bottlenose dolphin', 'sayigh2022sarasota', 'cetacean']
 }
 
diff --git a/plot_clusters.py b/plot_clusters.py
old mode 100644
new mode 100755
index 91d494f4548bb74d4157409bc43c0b2ff56039a4..e71515f124b9a60d94eefea46c14738fa542e5e2
--- a/plot_clusters.py
+++ b/plot_clusters.py
@@ -5,7 +5,7 @@ import models, utils as u
 
 
 species = np.loadtxt('good_species.txt', dtype=str)
-fig, ax = plt.subplots(nrows=len(species), figsize=(7, 10))
+fig, ax = plt.subplots(nrows=len(species), figsize=(7, 10), dpi=200)
 for i, specie in enumerate(species):
     meta = models.meta[specie]
     frontend = models.frontend['pcenMel'](meta['sr'], meta['nfft'], meta['sampleDur'], 128)
@@ -14,15 +14,18 @@ for i, specie in enumerate(species):
     df = pd.read_csv(f'{specie}/{specie}.csv')
     clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
     df.loc[idxs, 'cluster'] = clusters.astype(int)
-    for j, cluster in enumerate(np.random.choice(np.arange(df.cluster.max()), 5)):
-        for k, (x, name) in enumerate(torch.utils.data.DataLoader(u.Dataset(df[df.cluster==cluster].sample(10), f'{specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=1)):
-            ax[i].imshow(frontend(x).squeeze().numpy(), extent=[k, k+1, j, j+1], origin='lower', aspect='auto')
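+    # Tile spectrograms of 8 sampled vocalisations from each of 4 random clusters into the mosaic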
+    for j, cluster in enumerate(np.random.choice(np.arange(df.cluster.max()), 4)):
+        for k, (x, name) in enumerate(torch.utils.data.DataLoader(u.Dataset(df[df.cluster==cluster].sample(8), f'{specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=1)):
+            spec = frontend(x).squeeze().numpy()
+            ax[i].imshow(spec, extent=[k, k+1, j, j+1], origin='lower', aspect='auto', cmap='Greys', vmin=np.quantile(spec, .2), vmax=np.quantile(spec, .98))
     ax[i].set_xticks([])
     ax[i].set_yticks([])
     # ax[i].grid(color='w', xdata=np.arange(1, 10), ydata=np.arange(1, 5))
     ax[i].set_ylabel(specie.replace('_', ' '))
-    ax[i].set_xlim(0, 10)
-    ax[i].set_ylim(0, 5)
-
+    ax[i].set_xlim(0, 8)
+    ax[i].set_ylim(0, 4)
+    ax[i].vlines(np.arange(1, 8), 0, 4, linewidths=1, color='black')
+    ax[i].hlines(np.arange(1, 4), 0, 8, linewidths=1, color='black')
+plt.subplots_adjust(wspace=0.1)
 plt.tight_layout()
 plt.savefig('clusters.pdf')
\ No newline at end of file
diff --git a/plot_hdbscan_HP.py b/plot_hdbscan_HP.py
deleted file mode 100755
index ff0ac01dae8e9d335105d75d2696a4c9152dcf70..0000000000000000000000000000000000000000
--- a/plot_hdbscan_HP.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import matplotlib.pyplot as plt
-import pandas as pd, numpy as np
-
-
-frontends = ['biosound', 'vggish', '256_logMel128', '256_logSTFT', '256_Mel128', '256_pcenMel128', '512_pcenMel128', '128_pcenMel128', '64_pcenMel128', '32_pcenMel128', '16_pcenMel128']
-
-fn = open('hdbscan_HP.txt', 'r')
-
-out = []
-for l in fn.readlines():
-    l = l[:-1].split(' ')
-    if len(l)==2:
-        specie, frontend = l[0], l[1]
-    else:
-        out.append({'specie':specie, 'frontend':frontend, 'nmi':float(l[0]), 'mcs':int(l[1]), 'ms': l[2], 'eps': float(l[3]), 'al': l[4]})
-df = pd.DataFrame().from_dict(out)
-
-df.ms = df.ms.replace('None', 0).astype(int)
-
-df.to_csv('hdbscan_HP3.csv', index=False)
-
-best = df.loc[df.groupby(["specie", 'frontend']).nmi.idxmax()]
-res = [(conf, (grp.set_index(['specie', 'frontend']).nmi / best.set_index(['specie', 'frontend']).nmi).sum()) for conf, grp in df.groupby(['mcs', 'ms', 'eps', 'al'])]
-conf = res[np.argmax([r[1] for r in res])]
-print('best HP', conf)
-conf = conf[0]
-
-fig, ax = plt.subplots(ncols = 2, figsize=(10, 5), sharex=True, sharey=True)
-for s, grp in df.groupby('specie'):
-    ax[0].scatter([grp[grp.frontend==f].nmi.max() for f in frontends], np.arange(len(frontends)))
-ax[0].grid()
-ax[1].grid()
-ax[0].set_yticks(np.arange(len(frontends)))
-ax[0].set_yticklabels(frontends)
-for s, grp in df[(( df.mcs==conf[0] )&( df.ms==conf[1] )&( df.eps==conf[2] )&( df.al==conf[3] ))].groupby('specie'):
-    ax[1].scatter([grp[grp.frontend==f].nmi.max() for f in frontends], np.arange(len(frontends)), label=s)
-ax[1].legend(bbox_to_anchor=(1,1))
-plt.tight_layout()
-ax[0].set_title('Best HDBSCAN settings')
-ax[1].set_title('Shared HDBSCAN settings')
-ax[0].set_xlabel('NMI')
-ax[1].set_xlabel('NMI')
-plt.tight_layout()
-plt.savefig('NMIs_hdbscan.pdf')
diff --git a/plot_main_results.py b/plot_main_results.py
new file mode 100755
index 0000000000000000000000000000000000000000..75de2a853666619347d35ae92f30b20c939fd857
--- /dev/null
+++ b/plot_main_results.py
@@ -0,0 +1,31 @@
+import matplotlib.pyplot as plt
+import pandas as pd, numpy as np
+
+all_frontends = [['256_logMel128', '256_logMel128_noprcptl', 'spec32', 'biosound'][::-1],
+                 ['256_logMel128', 'allbut_AE_256_logMel128', 'openl3', 'wav2vec2', 'crepe'][::-1],
+                 ['256_logMel128', '256_Mel128', '256_pcenMel128', '256_logSTFT']]
+
+all_frontendNames = [['AE prcptl', 'AE MSE', 'Spectro.', 'PAFs'][::-1],
+                     ['AE', 'gen. AE', 'OpenL3', 'Wav2Vec2', 'CREPE'][::-1],
+                     ['log-Mel', 'Mel', 'PCEN-Mel', 'log-STFT']]
+
+all_plotNames = ['handcraft', 'deepembed', 'frontends']
+for frontends, frontendNames, plotName in zip(all_frontends, all_frontendNames, all_plotNames):
+    df = pd.read_csv('hdbscan_HP.csv')
+    df.loc[df.ms.isna(), 'ms'] = 0
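+    # Select the shared hyper-parameter configuration (ncomp, mcs, ms, eps, al) maximising the sum of per-(specie, frontend) NMIs normalised by each pair's best NMI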
+    best = df.loc[df.groupby(["specie", 'frontend']).nmi.idxmax()]
+    res = [(conf, (grp.set_index(['specie', 'frontend']).nmi / best.set_index(['specie', 'frontend']).nmi).sum()) for conf, grp in df.groupby(['ncomp', 'mcs', 'ms', 'eps', 'al'])]
+    conf = res[np.argmax([r[1] for r in res])]
+    print('best HP', conf)
+    conf = conf[0]
+
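+    # Keep only runs obtained with the selected shared configuration before plotting per-specie NMIs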
+    df = df[(df.ncomp==conf[0]) & (df.mcs==conf[1]) & (df.ms==conf[2]) & (df.eps==conf[3]) & (df.al==conf[4])]
+    df.specie = df.specie.replace('humpback2', 'humpback\n(small)')
+    species = df.specie.unique()
+    out = pd.DataFrame({fn:[df[((df.specie==s)&(df.frontend==f))].iloc[0].nmi for s in species] for f, fn in zip(frontends, frontendNames)}, index=[s.replace('_','\n') for s in species])
+    out.plot.bar(figsize=(9, 3), rot=0)
+    plt.legend(bbox_to_anchor=(1,1))
+    plt.ylabel('NMI')
+    plt.grid(axis='y')
+    plt.tight_layout()
+    plt.savefig(f'NMIs_hdbscan_barplot_{plotName}.pdf')
diff --git a/plot_prec_rec.py b/plot_prec_rec.py
old mode 100644
new mode 100755
index 4f5a7923a2cb28ee403d35ad8386c198e59b42c2..22f38f797ae156331408c01fbb871f29c3e9d6f9
--- a/plot_prec_rec.py
+++ b/plot_prec_rec.py
@@ -10,30 +10,43 @@ info = {
     'cassin_vireo': ['cassin vireo', 'arriaga2015bird', 'bird'],
     'black-headed_grosbeaks': ['black-headed grosbeaks', 'arriaga2015bird', 'bird'],
     'zebra_finch': ['zebra finch', 'elie2018zebra', 'bird'],
-    'otter': ['otter', '', ''],
     'humpback': ['humpback whale', 'malige2021use', 'cetacean'],
+    'humpback2': ['humpback whale (small)', 'malige2021use', 'cetacean'],
     'dolphin': ['bottlenose dolphin', 'sayigh2022sarasota', 'cetacean']
 }
 
-out = "\\textbf{Specie and source} & \\textbf{\# labels} & \\textbf{\# clusters} & \\textbf{\# discr. clusters} & \\textbf{\% clustered vocs} & \\textbf{\# missed labels} \\\\ \hline \n"
+#out = "\\textbf{Specie and source} & \\textbf{\# labels} & \\textbf{\# clusters} & \\textbf{\% discr. clusters} & \\textbf{\% clustered vocs} & \\textbf{\# missed labels} \\\\ \hline \n"
+out = ""
 
 for specie in species:
-    dic = np.load(f'{specie}/encodings//encodings_{specie}_256_pcenMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
-    idxs, X = dic['idxs'], dic['umap']
+    dic = np.load(f'{specie}/encodings/encodings_{specie}_256_logMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
+    idxs, X = dic['idxs'], dic['umap8']
     df = pd.read_csv(f'{specie}/{specie}.csv')
-    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
+    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpbacjjk' in specie else 'eom').fit_predict(X)
     df.loc[idxs, 'cluster'] = clusters.astype(int)
+
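+    # Repeat the clustering on the spec32 (raw spectrogram) baseline encodings, stored in a 'cluster2' column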
+    dic = np.load(f'{specie}/encodings/encodings_spec32.npy', allow_pickle=True).item()
+    idxs, X = dic['idxs'], dic['umap8']
+    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpbacjjk' in specie else 'eom').fit_predict(X)
+    df.loc[idxs, 'cluster2'] = clusters.astype(int)
+
     mask = ~df.loc[idxs].label.isna()
 
     print(specie)
     labelled = df[~df.label.isna()]
-    goodClusters, missedLabels = [], []
+    goodClusters, goodClusters2, missedLabels, missedLabels2 = [], [], [], []
     for l, grp in labelled.groupby('label'):
         precisions = grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()
         best = precisions.idxmax()
         goodClusters.extend(precisions[precisions > 0.9].index)
         if not (precisions > .9).any():
             missedLabels.append(l)
+
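+        # Apply the same >0.9 precision criterion to the baseline clustering (cluster2)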
+        precisions = grp.groupby('cluster2').fn.count() / labelled.groupby('cluster2').fn.count()
+        goodClusters2.extend(precisions[precisions > 0.9].index)
+        if not (precisions > .9).any():
+            missedLabels2.append(l)
+
         # print(f'Best precision for {l} is for cluster {best} with {(df.cluster==best).sum()} points, \
         # with precision {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.cluster==best).sum():.2f}\
         # and recall {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.label==l).sum():.2f}')
@@ -43,7 +56,9 @@ for specie in species:
     print(f'{len(goodClusters)/df.label.nunique():.1f} cluster per label in avg)')
     print(f'{len(missedLabels)} over {df.label.nunique()} missed labels')
 
-    out += f"{info[specie][0]} \cite{{{info[specie][1]}}} & {df.label.nunique()} & {df.cluster.nunique()-1} & {len(goodClusters)} & {df.cluster.isin(goodClusters).sum()/len(df)*100:.0f} & {len(missedLabels)} \\\\ \hline \n"
+    out += f"{info[specie][0]} \cite{{{info[specie][1]}}} & {df.label.nunique()} & {df.cluster.nunique()-1} & {df.cluster2.nunique()-1} & {len(goodClusters)/labelled.cluster.nunique()*100:.0f} &"
+    out += f"{len(goodClusters2)/labelled.cluster2.nunique()*100:.0f} & {df.cluster.isin(goodClusters).sum()/len(df[df.cluster.isin(labelled.cluster.unique())])*100:.0f} & "
+    out += f"{df.cluster2.isin(goodClusters2).sum()/len(df[df.cluster2.isin(labelled.cluster2.unique())])*100:.0f} & {len(missedLabels)} & {len(missedLabels2)} \\\\ \hline \n"
 
 f = open('cluster_distrib.tex', 'w')
 f.write(out)
diff --git a/plot_projections.py b/plot_projections.py
old mode 100644
new mode 100755
index 5ed7d55d3505103b2940d29489080ffd2cd6373c..1ad4cc9633b80a539cee7ecea81b36da749f9f0e
--- a/plot_projections.py
+++ b/plot_projections.py
@@ -1,33 +1,38 @@
+from tqdm import tqdm
 import matplotlib.pyplot as plt
 import numpy as np, pandas as pd
 
 
-species = np.loadtxt('good_species.txt', dtype=str)
-
-fig, ax = plt.subplots(ncols=4, nrows=2, figsize=(20, 10))
+fig, ax = plt.subplots(ncols=4, nrows=2, figsize=(13, 7))
 
 non_zero_min = lambda arr: np.min(arr[arr!=0])
 
-for i, specie in enumerate(species):
-    dic = np.load(f'{specie}/encodings//encodings_{specie}_256_pcenMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
+for i, specie in tqdm(enumerate(['bengalese_finch1', 'bengalese_finch2', 'california_thrashers', 'cassin_vireo', 'black-headed_grosbeaks', 'humpback', 'humpback2', 'dolphin'])):
+    dic = np.load(f'{specie}/encodings/encodings_{specie}_256_logMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
     df = pd.read_csv(f'{specie}/{specie}.csv')
     df.loc[dic['idxs'], 'umap_x'] = dic['umap'][:,0]
     df.loc[dic['idxs'], 'umap_y'] = dic['umap'][:,1]
-    ax[i//4,i%4].scatter(df[df.label.isna()].umap_x, df[df.label.isna()].umap_y, s=2, color='grey')
+    ax[i//4,i%4].scatter(df[df.label.isna()].umap_x, df[df.label.isna()].umap_y, s=1, color='grey')
     for label, grp in df[~df.label.isna()].groupby('label'):
         grp = grp.sample(min(len(grp), 1000))
-        ax[i//4,i%4].scatter(grp.umap_x, grp.umap_y, s=2)
-    ax[i//4,i%4].set_title(specie.replace('_', ' '))
-
+        ax[i//4,i%4].scatter(grp.umap_x, grp.umap_y, s=1)
+    ax[i//4,i%4].set_title(specie.replace('_', ' ') if specie != 'humpback2' else 'humpback (small)')
 
-    sampSize = 100
+    # Hopkins statistic
+    sampSize = len(df)//15
     X = df.sample(sampSize)[['umap_x', 'umap_y']].to_numpy()
     Y = np.vstack([np.random.normal(np.mean(X[:,0]), np.std(X[:,0]), sampSize), np.random.normal(np.mean(X[:,1]), np.std(X[:,1]), sampSize)]).T
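+    # U: nearest-neighbour distances from synthetic Gaussian reference points to the data, W: from real samples; H = U/(U+W) approaches 1 for clustered data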
     U = np.sum([min(np.sqrt(np.sum((y - df[['umap_x', 'umap_y']].to_numpy())**2, axis=1))) for y in Y])
     W = np.sum([non_zero_min(np.sqrt(np.sum((x - df[['umap_x', 'umap_y']].to_numpy())**2, axis=1))) for x in X])
     hopkins = U / (U + W)
-    ax[i//4, i%4].text(ax[i//4, i%4].get_xlim()[0] + .5, ax[i//4, i%4].get_ylim()[0] + .5, f'{hopkins:.2f}', fontsize=15)
+    ax[i//4, i%4].text(ax[i//4, i%4].get_xlim()[0] + .5, ax[i//4, i%4].get_ylim()[0] + .5, f'{hopkins:.2f}', fontsize=10)
+    ax[i//4, i%4].set_xticks([])
+    ax[i//4, i%4].set_yticks([])
+
+#ax[1,3].set_xticks([])
+#ax[1,3].set_yticks([])
+#ax[1,3].set_frame_on(False)
 
 plt.tight_layout()
 plt.savefig('projections.pdf')
-plt.savefig('projections.png')
\ No newline at end of file
+plt.savefig('projections.png')
diff --git a/plot_results_hdbcsan.py b/plot_results_hdbcsan.py
deleted file mode 100755
index 09cd3417896f539a4f5937a06536967b4a999265..0000000000000000000000000000000000000000
--- a/plot_results_hdbcsan.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import hdbscan
-import pandas as pd
-import matplotlib.pyplot as plt
-import numpy as np
-from sklearn import metrics
-from tqdm import tqdm
-import os
-
-species = np.loadtxt('good_species.txt', dtype=str)
-frontends = ['biosound'] #['vggish', '256_logMel128', '256_logSTFT', '256_Mel128', '32_pcenMel128', '64_pcenMel128', '128_pcenMel128', '256_pcenMel128', '512_pcenMel128']
-
-file = open('hdbscan_HP2.txt', 'w')
-
-plt.figure()
-for specie in ['humpback', 'dolphin', 'black-headed_grosbeaks', 'california_thrashers']: #species:
-    df = pd.read_csv(f'{specie}/{specie}.csv')
-    nmis = []
-    for i, frontend in tqdm(enumerate(frontends), desc=specie, total=len(frontends)):
-        file.write(specie+' '+frontend+'\n')
-        fn = f'{specie}/encodings/encodings_' + (f'{specie}_{frontend}_sparrow_encoder_decod2_BN_nomaxPool.npy' if not frontend in ['vggish', 'biosound'] else frontend+'.npy')
-        if not os.path.isfile(fn):
-            nmis.append(None)
-            print('not found')
-            continue
-        dic = np.load(fn, allow_pickle=True).item()
-        idxs,  X = dic['idxs'], dic['umap']
-
-        # clusters = hdbscan.HDBSCAN(min_cluster_size=max(10, len(df)//100), core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
-        # df.loc[idxs, 'cluster'] = clusters.astype(int)
-        # mask = ~df.loc[idxs].label.isna()
-        # clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-        # nmis.append(metrics.normalized_mutual_info_score(labels, clusters))
-        # df.drop('cluster', axis=1, inplace=True)
-        # continue
-
-        nmi = 0
-        for mcs in [10, 20, 50, 100, 150, 200]:
-            for ms in [None, 3, 5, 10, 20, 30]:
-                for eps in [0.0, 0.01, 0.02, 0.05, 0.1]:
-                    for al in ['leaf', 'eom']:
-                        clusters = hdbscan.HDBSCAN(min_cluster_size=mcs, min_samples=ms, cluster_selection_epsilon=eps, \
-                            core_dist_n_jobs=-1, cluster_selection_method=al).fit_predict(X)
-                        df.loc[idxs, 'cluster'] = clusters.astype(int)
-                        mask = ~df.loc[idxs].label.isna()
-                        clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-                        tnmi = metrics.normalized_mutual_info_score(labels, clusters)
-                        file.write(f'{tnmi} {mcs} {ms} {eps} {al}\n')
-                        if tnmi > nmi :
-                            nmi
-                        df.drop('cluster', axis=1, inplace=True)
-        nmis.append(nmi)
-    plt.scatter(nmis, np.arange(len(frontends)), label=specie)
-
-file.close()
-
-plt.yticks(range(len(frontends)), frontends)
-plt.ylabel('archi')
-plt.xlabel('NMI with expert labels')
-plt.grid()
-plt.tight_layout()
-plt.legend()
-plt.savefig('NMIs_hdbscan.pdf')
-plt.close()
diff --git a/print_annot.py b/print_annot.py
index ac568d4c98416b606cc30055887ac0c510382f41..da77b06dc33611633b0de7e461b9dac729effc71 100755
--- a/print_annot.py
+++ b/print_annot.py
@@ -1,7 +1,7 @@
-import os, argparse
+import os, shutil, argparse
 from tqdm import tqdm
 import matplotlib.pyplot as plt
-import pandas as pd
+import pandas as pd, numpy as np
 import models, utils as u
 import torch
 
@@ -14,19 +14,21 @@ args = parser.parse_args()
 meta = models.meta[args.specie]
 df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
 frontend = models.frontend[args.frontend](meta['sr'], meta['nfft'], meta['sampleDur'], args.nMel)
-os.system(f'rm -R {args.specie}/annot_pngs/*')
+shutil.rmtree(f'{args.specie}/annot_pngs', ignore_errors=True)
+
 for label, grp in df.groupby('label'):
-    os.system(f'mkdir -p "{args.specie}/annot_pngs/{label}"')
+    os.makedirs(f'{args.specie}/annot_pngs/{label}', exist_ok=True)
     loader = torch.utils.data.DataLoader(u.Dataset(grp, args.specie+'/audio/', meta['sr'], meta['sampleDur']),\
                                          batch_size=1, num_workers=4, pin_memory=True)
     for x, idx in tqdm(loader, desc=args.specie + ' ' + label, leave=False):
         x = frontend(x).squeeze().detach()
         assert not torch.isnan(x).any(), "Found a NaN in spectrogram... :/"
         plt.figure()
-        plt.imshow(x, origin='lower', aspect='auto')
+        plt.imshow(x, origin='lower', aspect='auto', cmap='Greys', vmin=np.quantile(x, .7))
+        plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
         row = df.loc[idx.item()]
-        plt.savefig(f'{args.specie}/annot_pngs/{label}/{row.fn.split(".")[0]}_{row.pos:.2f}.png')
-        # plt.savefig(f'{args.specie}/annot_pngs/{label}/{idx.item()}')
+        #plt.savefig(f'{args.specie}/annot_pngs/{label}/{row.fn.split(".")[0]}_{row.pos:.2f}.png')
+        plt.savefig(f'{args.specie}/annot_pngs/{label}/{idx.item()}')
         plt.close()
 
 
diff --git a/print_reconstr.py b/print_reconstr.py
new file mode 100755
index 0000000000000000000000000000000000000000..66efcbf29cf6e88038362fd63b9ed3efd7fb1068
--- /dev/null
+++ b/print_reconstr.py
@@ -0,0 +1,48 @@
+import os, shutil, argparse
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+import pandas as pd, numpy as np
+import models, utils as u
+import torch
+
+parser = argparse.ArgumentParser()
+parser.add_argument("specie", type=str)
+parser.add_argument("-bottleneck", type=int, default=16)
+parser.add_argument("-frontend", type=str, default='logMel')
+parser.add_argument("-encoder", type=str, default='sparrow_encoder')
+parser.add_argument("-prcptl", type=int, default=1)
+parser.add_argument("-nMel", type=int, default=128)
+args = parser.parse_args()
+
+modelname = f'{args.specie}_{args.bottleneck}_{args.frontend}{args.nMel if "Mel" in args.frontend else ""}_{args.encoder}_decod2_BN_nomaxPool{"_noprcptl" if args.prcptl==0 else ""}.stdc'
+print(modelname)
+gpu = torch.device(f'cuda')
+
+meta = models.meta[args.specie]
+frontend = models.frontend[args.frontend](meta['sr'], meta['nfft'], meta['sampleDur'], args.nMel)
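+# Rebuild the trained auto-encoder (frontend -> encoder -> decoder) and load its weights for inference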
+encoder = models.__dict__[args.encoder](*((args.bottleneck // 16, (4, 4)) if args.nMel == 128 else (args.bottleneck // 8, (2, 4))))
+decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4))
+model = torch.nn.Sequential(frontend, encoder, decoder).to(gpu)
+model.load_state_dict(torch.load(f'{args.specie}/weights/{modelname}'))
+model.eval()
+
+df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
+
+shutil.rmtree(f'{args.specie}/reconstruct_pngs', ignore_errors=True)
+
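+# Save one reconstructed spectrogram image per vocalisation, grouped by annotation label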
+for label, grp in df.groupby('label'):
+    os.makedirs(f'{args.specie}/reconstruct_pngs/{label}', exist_ok=True)
+    loader = torch.utils.data.DataLoader(u.Dataset(grp, args.specie+'/audio/', meta['sr'], meta['sampleDur']),\
+                                         batch_size=1, num_workers=4, pin_memory=True)
+    with torch.no_grad():
+        for x, idx in tqdm(loader, desc=args.specie + ' ' + label, leave=False):
+            x = model(x.to(gpu)).squeeze().detach().cpu()
+            assert not torch.isnan(x).any(), "Found a NaN in spectrogram... :/"
+            plt.imshow(x, origin='lower', aspect='auto') #, cmap='Greys', vmin=np.quantile(x, .7))
+            plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
+            row = df.loc[idx.item()]
+            plt.savefig(f'{args.specie}/reconstruct_pngs/{label}/{idx.item()}')
+            plt.close()
+
+
+
diff --git a/run_baseline.py b/run_baseline.py
index ca30c1c7d7e39304ae29ec739c81da15f3a7de48..dcea284b6418d546cb510c6c99c18046cc66a5d1 100755
--- a/run_baseline.py
+++ b/run_baseline.py
@@ -43,7 +43,7 @@ def process(idx):
     
     return [sound.__dict__[f] for f in feats]
 
-res = p_tqdm.p_map(process, df.index, num_cpus=14)
+res = p_tqdm.p_map(process, df.index, num_cpus=10)
 
 for i, mr in zip(df.index, res):
     for f, r in zip(feats, mr):
diff --git a/run_hearbaseline.py b/run_hearbaseline.py
index e5b9a3f2a8711941af8dec04f57259638f6fa897..52ba7e6af44ca3d962cec2b99646b814aa753dcb 100755
--- a/run_hearbaseline.py
+++ b/run_hearbaseline.py
@@ -1,11 +1,9 @@
 from sklearn import metrics
-import matplotlib.pyplot as plt
-#import umap, hdbscan
 from tqdm import tqdm
 import argparse, os
 import models, utils as u
 import pandas as pd, numpy as np, torch
-from hearbaseline import wav2vec2 as hear
+from hearbaseline import wav2vec2, torchopenl3, torchcrepe, vggish
 
 parser = argparse.ArgumentParser()
 parser.add_argument("specie", type=str)
@@ -17,57 +15,17 @@ df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
 meta = models.meta[args.specie]
 batch_size = 32
 
-if True : #not os.path.isfile(f'{args.specie}/encodings/encodings_wave2vec2.npy'):
-    gpu = torch.device(f'cuda:{args.cuda}')
-    model = hear.load_model().to(gpu)
+
+gpu = torch.device(f'cuda:{args.cuda}')
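+# Compute scene embeddings with each pretrained baseline (OpenL3, wav2vec2, CREPE, VGGish) and save them per model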
+for module, name in zip([torchopenl3, wav2vec2, torchcrepe, vggish], ['openl3', 'wav2vec2', 'crepe', 'vggish']):
+    model = module.load_model().to(gpu)
     loader = torch.utils.data.DataLoader(u.Dataset(df, f'{args.specie}/audio/', model.sample_rate, meta['sampleDur']), batch_size=batch_size, num_workers=8, collate_fn=u.collate_fn)
     with torch.inference_mode():
         encodings, idxs = [], []
         for x, idx in tqdm(loader, desc='test '+args.specie, leave=False):
-            encoding = hear.get_scene_embeddings(x.to(gpu), model=model)
+            encoding = module.get_scene_embeddings(x.to(gpu), model=model)
             idxs.extend(idx.numpy())
             encodings.extend(encoding.view(len(x), -1).cpu().numpy())
     idxs, encodings = np.array(idxs), np.stack(encodings)
-#    X = umap.UMAP(n_jobs=-1, n_components=8).fit_transform(encodings)
-    np.save(f'{args.specie}/encodings/encodings_wav2vec2.npy', {'idxs':idxs, 'encodings':encodings}) #, 'umap8':X})
-    exit()
-else:
-    dic = np.load(f'{args.specie}/encodings/encodings_wave2vec2.npy', allow_pickle=True).item()
-    idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap8']
-
-clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
-df.loc[idxs, 'cluster'] = clusters.astype(int)
-mask = ~df.loc[idxs].label.isna()
-
-clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-print('NMI', metrics.normalized_mutual_info_score(labels, clusters))
-exit()
-#print('Found clusters : \n', pd.Series(clusters).value_counts())
-
-plt.figure(figsize=(20, 10))
-plt.scatter(X[clusters==-1,0], X[clusters==-1,1], s=2, alpha=.2, color='Grey')
-plt.scatter(X[clusters!=-1,0], X[clusters!=-1,1], s=2, c=clusters[clusters!=-1], cmap='tab20')
-plt.tight_layout()
-plt.savefig(f'{args.specie}/projections/vggish_projection_clusters.png')
-
-plt.figure(figsize=(20, 10))
-plt.scatter(X[~mask,0], X[~mask,1], s=2, alpha=.2, color='Grey')
-for l, grp in df.groupby('label'):
-    plt.scatter(X[df.loc[idxs].label==l, 0], X[df.loc[idxs].label==l, 1], s=4, label=l)
-plt.legend()
-plt.tight_layout()
-plt.savefig(f'{args.specie}/projections/vggish_projection_labels.png')
-
-
-clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-print('Silhouette', metrics.silhouette_score(encodings[mask], clusters))
-print('NMI', metrics.normalized_mutual_info_score(labels, clusters))
-print('Homogeneity', metrics.homogeneity_score(labels, clusters))
-print('Completeness', metrics.completeness_score(labels, clusters))
-print('V-Measure', metrics.v_measure_score(labels, clusters))
-
-labelled = df[~df.label.isna()]
-for l, grp in labelled.groupby('label'):
-    best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax()
-    print(f'Best precision for {l} is for cluster {best} with {(df.cluster==best).sum()} points, \
-with precision {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.cluster==best).sum():.2f} and recall {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.label==l).sum():.2f}')
+    #    X = umap.UMAP(n_jobs=-1, n_components=8).fit_transform(encodings)
+    np.save(f'{args.specie}/encodings/encodings_{name}.npy', {'idxs':idxs, 'encodings':encodings}) #, 'umap8':X})
diff --git a/run_openl3.py b/run_openl3.py
deleted file mode 100755
index aefb97664c6f9093c1c250a24439361f10937e7f..0000000000000000000000000000000000000000
--- a/run_openl3.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from sklearn import metrics
-import matplotlib.pyplot as plt
-import umap, hdbscan
-from tqdm import tqdm
-import argparse, os
-import models, utils as u
-import pandas as pd, numpy as np, torch
-import torchopenl3 as openl3
-
-parser = argparse.ArgumentParser()
-parser.add_argument("specie", type=str)
-parser.add_argument("-cuda", type=int, default=0)
-args = parser.parse_args()
-
-df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
-
-meta = models.meta[args.specie]
-batch_size = 64
-
-if True : #not os.path.isfile(f'{args.specie}/encodings/encodings_openl3.npy'):
-    gpu = torch.device(f'cuda:{args.cuda}')
-    frontend = models.frontend['logMel_vggish'](meta['sr'], meta['nfft'], meta['sampleDur'], 64)
-    loader = torch.utils.data.DataLoader(u.Dataset(df, f'{args.specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=batch_size, num_workers=8, collate_fn=u.collate_fn)
-    model = openl3.models.load_audio_embedding_model(input_repr="mel128", content_type="music", embedding_size=512).to(gpu)
-    with torch.no_grad():
-        encodings, idxs = [], []
-        for x, idx in tqdm(loader, desc='test '+args.specie, leave=False):
-            encoding = openl3.get_audio_embedding(x.to(gpu), meta['sr'], model=model, center=False, batch_size=batch_size, verbose=False)[0]
-            idxs.extend(idx.numpy())
-            encodings.extend(encoding.mean(axis=1).cpu().numpy())
-
-    idxs, encodings = np.array(idxs), np.stack(encodings)
-    X = umap.UMAP(n_jobs=-1, n_components=8).fit_transform(encodings)
-    np.save(f'{args.specie}/encodings/encodings_openl3.npy', {'idxs':idxs, 'encodings':encodings, 'umap8':X})
-else:
-    dic = np.load(f'{args.specie}/encodings/encodings_openl3.npy', allow_pickle=True).item()
-    idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap8']
-
-clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
-df.loc[idxs, 'cluster'] = clusters.astype(int)
-mask = ~df.loc[idxs].label.isna()
-
-clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-print('NMI', metrics.normalized_mutual_info_score(labels, clusters))
-exit()
-#print('Found clusters : \n', pd.Series(clusters).value_counts())
-
-plt.figure(figsize=(20, 10))
-plt.scatter(X[clusters==-1,0], X[clusters==-1,1], s=2, alpha=.2, color='Grey')
-plt.scatter(X[clusters!=-1,0], X[clusters!=-1,1], s=2, c=clusters[clusters!=-1], cmap='tab20')
-plt.tight_layout()
-plt.savefig(f'{args.specie}/projections/vggish_projection_clusters.png')
-
-plt.figure(figsize=(20, 10))
-plt.scatter(X[~mask,0], X[~mask,1], s=2, alpha=.2, color='Grey')
-for l, grp in df.groupby('label'):
-    plt.scatter(X[df.loc[idxs].label==l, 0], X[df.loc[idxs].label==l, 1], s=4, label=l)
-plt.legend()
-plt.tight_layout()
-plt.savefig(f'{args.specie}/projections/vggish_projection_labels.png')
-
-
-clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-print('Silhouette', metrics.silhouette_score(encodings[mask], clusters))
-print('NMI', metrics.normalized_mutual_info_score(labels, clusters))
-print('Homogeneity', metrics.homogeneity_score(labels, clusters))
-print('Completeness', metrics.completeness_score(labels, clusters))
-print('V-Measure', metrics.v_measure_score(labels, clusters))
-
-labelled = df[~df.label.isna()]
-for l, grp in labelled.groupby('label'):
-    best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax()
-    print(f'Best precision for {l} is for cluster {best} with {(df.cluster==best).sum()} points, \
-with precision {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.cluster==best).sum():.2f} and recall {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.label==l).sum():.2f}')
diff --git a/run_vggish.py b/run_vggish.py
deleted file mode 100755
index c7bcacc8c5dac05650c2288616b3914e174e74f8..0000000000000000000000000000000000000000
--- a/run_vggish.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from sklearn import metrics
-import matplotlib.pyplot as plt
-import umap, hdbscan
-from tqdm import tqdm
-import argparse, os
-import models, utils as u
-import pandas as pd, numpy as np, torch
-torch.multiprocessing.set_sharing_strategy('file_system')
-
-parser = argparse.ArgumentParser()
-parser.add_argument("specie", type=str)
-args = parser.parse_args()
-
-df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
-
-meta = models.meta[args.specie]
-
-if not os.path.isfile(f'{args.specie}/encodings/encodings_vggish.npy'):
-    gpu = torch.device('cuda')
-    frontend = models.frontend['logMel_vggish'](meta['sr'], meta['nfft'], meta['sampleDur'], 64)
-    vggish = torch.hub.load('harritaylor/torchvggish', 'vggish')
-    # vggish.preprocess = False
-    vggish.postprocess = False
-    model = torch.nn.Sequential(frontend, vggish).to(gpu)
-    model.eval()
-    loader = torch.utils.data.DataLoader(u.Dataset(df, f'{args.specie}/audio/', 16000, 1), batch_size=1, shuffle=True, num_workers=8, collate_fn=u.collate_fn)
-    with torch.no_grad():
-        encodings, idxs = [], []
-        for x, idx in tqdm(loader, desc='test '+args.specie, leave=False):
-            # encoding = model(x.to(gpu))
-            encoding = vggish(x.numpy().squeeze(0), fs=16000)
-            idxs.extend(idx.numpy())
-            encodings.extend(encoding.cpu().numpy())
-
-    idxs, encodings = np.array(idxs), np.stack(encodings)
-    X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
-    np.save(f'{args.specie}/encodings/encodings_vggish.npy', {'idxs':idxs, 'encodings':encodings, 'umap':X})
-else:
-    dic = np.load(f'{args.specie}/encodings/encodings_vggish.npy', allow_pickle=True).item()
-    idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap']
-
-clusters = hdbscan.HDBSCAN(min_cluster_size=50, min_samples=5, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
-df.loc[idxs, 'cluster'] = clusters.astype(int)
-mask = ~df.loc[idxs].label.isna()
-
-#print('Found clusters : \n', pd.Series(clusters).value_counts())
-
-plt.figure(figsize=(20, 10))
-plt.scatter(X[clusters==-1,0], X[clusters==-1,1], s=2, alpha=.2, color='Grey')
-plt.scatter(X[clusters!=-1,0], X[clusters!=-1,1], s=2, c=clusters[clusters!=-1], cmap='tab20')
-plt.tight_layout()
-plt.savefig(f'{args.specie}/projections/vggish_projection_clusters.png')
-
-plt.figure(figsize=(20, 10))
-plt.scatter(X[~mask,0], X[~mask,1], s=2, alpha=.2, color='Grey')
-for l, grp in df.groupby('label'):
-    plt.scatter(X[df.loc[idxs].label==l, 0], X[df.loc[idxs].label==l, 1], s=4, label=l)
-plt.legend()
-plt.tight_layout()
-plt.savefig(f'{args.specie}/projections/vggish_projection_labels.png')
-
-
-clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-print('Silhouette', metrics.silhouette_score(encodings[mask], clusters))
-print('NMI', metrics.normalized_mutual_info_score(labels, clusters))
-print('Homogeneity', metrics.homogeneity_score(labels, clusters))
-print('Completeness', metrics.completeness_score(labels, clusters))
-print('V-Measure', metrics.v_measure_score(labels, clusters))
-
-labelled = df[~df.label.isna()]
-for l, grp in labelled.groupby('label'):
-    best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax()
-    print(f'Best precision for {l} is for cluster {best} with {(df.cluster==best).sum()} points, \
-with precision {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.cluster==best).sum():.2f} and recall {((labelled.cluster==best)&(labelled.label==l)).sum()/(labelled.label==l).sum():.2f}')
diff --git a/test_AE.py b/test_AE.py
index 5fcc0ae21a012e2964529e1b24c9dee80d341dde..92f8100885ee4e7c624cf6109de493a2f7446349 100755
--- a/test_AE.py
+++ b/test_AE.py
@@ -56,7 +56,7 @@ else:
 if X is None:
     X = umap.UMAP(n_jobs=-1, n_components=8).fit_transform(encodings)
     print('/!\ no UMAP')
-clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=.1, cluster_selection_method='leaf', core_dist_n_jobs=-1).fit_predict(X)
+clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=.1, cluster_selection_method='leaf' if not 'humpback' in args.specie else 'eom', core_dist_n_jobs=-1).fit_predict(X)
 df.loc[idxs, 'cluster'] = clusters.astype(int)
 mask = ~df.loc[idxs].label.isna()
 
diff --git a/train_AE.py b/train_AE.py
index b13dca3027529a4a9ade28cc6a007fe492cefcd2..d02907e84b44197240b7173284d8dc6ff5eb86a4 100755
--- a/train_AE.py
+++ b/train_AE.py
@@ -1,10 +1,7 @@
-import umap, hdbscan
 from torchvision.utils import make_grid
 from torch.utils.tensorboard import SummaryWriter
 import torch
-from sklearn import metrics, cluster
 import numpy as np, pandas as pd
-from scipy.stats import linregress
 import utils as u, models
 from tqdm import tqdm
 import os, argparse, warnings
@@ -13,9 +10,10 @@ warnings.filterwarnings("error")
 
 parser = argparse.ArgumentParser()
 parser.add_argument("specie", type=str)
-parser.add_argument("-bottleneck", type=int, default=16)
+parser.add_argument("-bottleneck", type=int, default=256)
 parser.add_argument("-frontend", type=str, default='logMel')
 parser.add_argument("-encoder", type=str, default='sparrow_encoder')
+parser.add_argument("-prcptl", type=int, default=1)
 parser.add_argument("-nMel", type=int, default=128)
 parser.add_argument("-lr", type=float, default=3e-3)
 parser.add_argument("-lr_decay", type=float, default=1e-2)
@@ -26,10 +24,10 @@ args = parser.parse_args()
 df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
 print(f'{len(df)} available vocs')
 
-modelname = f'{args.specie}_{args.bottleneck}_{args.frontend}{args.nMel if "Mel" in args.frontend else ""}_{args.encoder}_decod2_BN_nomaxPool.stdc'
+modelname = f'{args.specie}_{args.bottleneck}_{args.frontend}{args.nMel if "Mel" in args.frontend else ""}_{args.encoder}_decod2_BN_nomaxPool{"_noprcptl" if args.prcptl==0 else ""}.stdc'
 gpu = torch.device(f'cuda:{args.cuda}')
-writer = SummaryWriter(f'runs2/{modelname}')
-os.system(f'cp *.py runs2/{modelname}')
+writer = SummaryWriter(f'runs/{modelname}')
+os.system(f'cp *.py runs/{modelname}')
 vgg16 = models.vgg16
 vgg16.eval().to(gpu)
 meta = models.meta[args.specie]
@@ -56,8 +54,11 @@ for epoch in range(100_000//len(loader)):
         x = encoder(label)
         pred = decoder(x)
         assert not torch.isnan(pred).any(), "found a NaN :'("
-        predd = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:]))
-        labell = vgg16(label.expand(label.shape[0], 3, *label.shape[2:]))
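+        # Perceptual loss: compare VGG16 feature maps of prediction and target; otherwise use plain MSE on the spectrograms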
+        if args.prcptl == 1:
+            predd = vgg16(pred.expand(pred.shape[0], 3, *pred.shape[2:]))
+            labell = vgg16(label.expand(label.shape[0], 3, *label.shape[2:]))
+        else:
+            predd, labell = pred, label
 
         score = MSE(predd, labell)
         score.backward()
@@ -69,10 +70,14 @@ for epoch in range(100_000//len(loader)):
             print('Early stop')
             torch.save(model.state_dict(), f'{args.specie}/weights/{modelname}')
             exit()
+
         step += 1
-        continue
-        # TEST ROUTINE
-        if step % 500 == 0:
+
+        if step % 500 == 0: # scheduler
+            scheduler.step()
+
+        # Plot images
+        if step % 100 == 0 :
             # Plot reconstructions
             images = [(e-e.min())/(e.max()-e.min()) for e in label[:8]]
             grid = make_grid(images)
@@ -80,81 +85,4 @@ for epoch in range(100_000//len(loader)):
             # writer.add_embedding(x.detach(), global_step=step, label_img=label)
             images = [(e-e.min())/(e.max()-e.min()) for e in pred[:8]]
             grid = make_grid(images)
-            writer.add_image('reconstruct', grid, step)
-
-            torch.save(model.state_dict(), f'{args.specie}/{modelname}')
-            scheduler.step()
-
-            # Actual test
-            model[1:].eval()
-            with torch.no_grad():
-                encodings, idxs = [], []
-                for x, idx in tqdm(loader, desc='test '+str(step), leave=False):
-                    encoding = model[:2](x.to(gpu))
-                    idxs.extend(idx)
-                    encodings.extend(encoding.cpu().detach())
-            idxs, encodings = np.array(idxs), np.stack(encodings)
-            print('Computing UMAP...', end='')
-            try:
-                X = umap.UMAP(n_jobs=-1).fit_transform(encodings)
-            except UserWarning:
-                pass
-            print('\rRunning HDBSCAN...', end='')
-            clusters = hdbscan.HDBSCAN(min_cluster_size=len(df)//100, min_samples=5, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
-            # df.loc[idxs, 'cluster'] = clusters.astype(int)
-            mask = ~df.loc[idxs].label.isna()
-            clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-            writer.add_scalar('NMI HDBSCAN', metrics.normalized_mutual_info_score(labels, clusters), step)
-            try:
-                writer.add_scalar('ARI HDBSCAN', metrics.adjusted_rand_score(labels, clusters), step)
-            except:
-                pass
-            writer.add_scalar('Homogeneity HDBSCAN', metrics.homogeneity_score(labels, clusters), step)
-            writer.add_scalar('Completeness HDBSCAN', metrics.completeness_score(labels, clusters), step)
-            writer.add_scalar('V-Measure HDBSCAN', metrics.v_measure_score(labels, clusters), step)
-            
-            # print('\rComputing HDBSCAN precision and recall distributions', end='')
-            # labelled = df[~df.label.isna()]
-            # precs, recs = [], []
-            # for l, grp in labelled.groupby('label'):
-            #     best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax()
-            #     precs.append((grp.cluster==best).sum()/(labelled.cluster==best).sum())
-            #     recs.append((grp.cluster==best).sum()/len(grp))
-            # writer.add_histogram('HDBSCAN Precisions ', np.array(precs), step)
-            # writer.add_histogram('HDBSCAN Recalls ', np.array(recs), step)
-            # df.drop('cluster', axis=1, inplace=True)
-            # print('\rRunning elbow method for K-Means...', end='')
-            # ks = (5*1.2**np.arange(20)).astype(int)
-            # distorsions = [cluster.KMeans(n_clusters=k).fit(encodings).inertia_ for k in ks]
-            # print('\rEstimating elbow...', end='')
-            # errors = [linregress(ks[:i], distorsions[:i]).stderr + linregress(ks[i+1:], distorsions[i+1:]).stderr for i in range(2, len(ks)-2)]
-            # k = ks[np.argmin(errors)]
-            # writer.add_scalar('Chosen K', k, step)
-            # clusters = cluster.KMeans(n_clusters=k).fit_predict(encodings)
-            # df.loc[idxs, 'cluster'] = clusters.astype(int)
-
-            # writer.add_scalar('Silhouette', metrics.silhouette_score(encodings, clusters), step)
-            # clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-            # writer.add_scalar('NMI K-Means', metrics.normalized_mutual_info_score(labels, clusters), step)
-            # try:
-            #     writer.add_scalar('ARI K-Means', metrics.adjusted_rand_score(labels, clusters), step)
-            # except:
-            #     pass
-            # writer.add_scalar('Homogeneity K-Means', metrics.homogeneity_score(labels, clusters), step)
-            # writer.add_scalar('Completeness K-Means', metrics.completeness_score(labels, clusters), step)
-            # writer.add_scalar('V-Measure K-Means', metrics.v_measure_score(labels, clusters), step)
-
-            # print('\rComputing K-Means precision and recall distributions', end='') 
-            # labelled = df[~df.label.isna()]
-            # precs, recs = [], []
-            # for l, grp in labelled.groupby('label'):
-            #     best = (grp.groupby('cluster').fn.count() / labelled.groupby('cluster').fn.count()).idxmax()
-            #     precs.append((grp.cluster==best).sum()/(labelled.cluster==best).sum())
-            #     recs.append((grp.cluster==best).sum()/len(grp))
-            # writer.add_histogram('K-Means Precisions ', np.array(precs), step)
-            # writer.add_histogram('K-Means Recalls ', np.array(recs), step)
-            # df.drop('cluster', axis=1, inplace=True)
-
-            print('\r', end='')
-            model[1:].train()
-        step += 1
+            writer.add_image('reconstruct', grid, step)
\ No newline at end of file