diff --git a/.gitignore b/.gitignore
index 76ccee7b16358c65b6c72e2d84db0c85d666f015..3049b72a7212aba0d198157d05cd57ae56780eae 100755
--- a/.gitignore
+++ b/.gitignore
@@ -1,17 +1,18 @@
 *.png
 *.stdc
 *.npy
-*/audio
+paper_experiments/*/audio
 */TextGrid
 *__pycache__
 *log
-humpback2/annots
+paper_experiments/humpback2/annots
 gibbon
 new_specie/*/
 otter/pone.0112562.s003.xlsx
 zebra_finch/Library_notes.pdf
 annot_distrib.pdf
 annot_distrib.tex
-humpback/annot
+paper_experiments/humpback/annot
 humpback_CARIMAM/
-dolphin/zips
+paper_experiments/dolphin/zips
+archive/*
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e3f14220ccbd6964cf06f183f87fd708d9ea595a
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,240 @@
+name: hear
+channels:
+  - pytorch-nightly
+  - nvidia
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - asttokens=2.0.5=pyhd3eb1b0_0
+  - backcall=0.2.0=pyhd3eb1b0_0
+  - blas=1.0=mkl
+  - bottleneck=1.3.5=py38h7deecbd_0
+  - brotli=1.0.9=h5eee18b_7
+  - brotli-bin=1.0.9=h5eee18b_7
+  - bzip2=1.0.8=h7b6447c_0
+  - ca-certificates=2023.01.10=h06a4308_0
+  - certifi=2023.5.7=py38h06a4308_0
+  - cffi=1.15.1=py38h5eee18b_3
+  - charset-normalizer=2.0.4=pyhd3eb1b0_0
+  - contourpy=1.0.5=py38hdb19cb5_0
+  - cryptography=39.0.1=py38h9ce1e76_0
+  - cuda-cudart=12.1.105=0
+  - cuda-cupti=12.1.105=0
+  - cuda-libraries=12.1.0=0
+  - cuda-nvrtc=12.1.105=0
+  - cuda-nvtx=12.1.105=0
+  - cuda-opencl=12.1.105=0
+  - cuda-runtime=12.1.0=0
+  - cycler=0.11.0=pyhd3eb1b0_0
+  - dbus=1.13.18=hb2f20db_0
+  - decorator=5.1.1=pyhd3eb1b0_0
+  - executing=0.8.3=pyhd3eb1b0_0
+  - expat=2.4.9=h6a678d5_0
+  - ffmpeg=4.2.2=h20bf706_0
+  - filelock=3.9.0=py38h06a4308_0
+  - fontconfig=2.14.1=h4c34cd2_2
+  - fonttools=4.25.0=pyhd3eb1b0_0
+  - freetype=2.12.1=h4a9f257_0
+  - giflib=5.2.1=h5eee18b_3
+  - glib=2.69.1=he621ea3_2
+  - gmp=6.2.1=h295c915_3
+  - gmpy2=2.1.2=py38heeb90bb_0
+  - gnutls=3.6.15=he1e5248_0
+  - gst-plugins-base=1.14.1=h6a678d5_1
+  - gstreamer=1.14.1=h5eee18b_1
+  - icu=58.2=he6710b0_3
+  - idna=3.4=py38h06a4308_0
+  - importlib_metadata=6.0.0=hd3eb1b0_0
+  - importlib_resources=5.2.0=pyhd3eb1b0_1
+  - intel-openmp=2023.1.0=hdb19cb5_46305
+  - ipython=8.12.0=py38h06a4308_0
+  - jedi=0.18.1=py38h06a4308_1
+  - jinja2=3.1.2=py38h06a4308_0
+  - jpeg=9e=h5eee18b_1
+  - kiwisolver=1.4.4=py38h6a678d5_0
+  - krb5=1.19.4=h568e23c_0
+  - lame=3.100=h7b6447c_0
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - lerc=3.0=h295c915_0
+  - libbrotlicommon=1.0.9=h5eee18b_7
+  - libbrotlidec=1.0.9=h5eee18b_7
+  - libbrotlienc=1.0.9=h5eee18b_7
+  - libclang13=14.0.6=default_he11475f_1
+  - libcublas=12.1.0.26=0
+  - libcufft=11.0.2.4=0
+  - libcufile=1.6.1.9=0
+  - libcurand=10.3.2.106=0
+  - libcusolver=11.4.4.55=0
+  - libcusparse=12.0.2.55=0
+  - libdeflate=1.17=h5eee18b_0
+  - libedit=3.1.20221030=h5eee18b_0
+  - libevent=2.1.12=h8f2d780_0
+  - libffi=3.4.4=h6a678d5_0
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgfortran-ng=7.5.0=ha8ba4b0_17
+  - libgfortran4=7.5.0=ha8ba4b0_17
+  - libgomp=11.2.0=h1234567_1
+  - libidn2=2.3.4=h5eee18b_0
+  - libllvm14=14.0.6=hdb19cb5_3
+  - libnpp=12.0.2.50=0
+  - libnvjitlink=12.1.105=0
+  - libnvjpeg=12.1.0.39=0
+  - libopus=1.3.1=h7b6447c_0
+  - libpng=1.6.39=h5eee18b_0
+  - libpq=12.9=h16c4e8d_3
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libtasn1=4.19.0=h5eee18b_0
+  - libtiff=4.5.0=h6a678d5_2
+  - libunistring=0.9.10=h27cfd23_0
+  - libuuid=1.41.5=h5eee18b_0
+  - libvpx=1.7.0=h439df22_0
+  - libwebp=1.2.4=h11a3e52_1
+  - libwebp-base=1.2.4=h5eee18b_1
+  - libxcb=1.15=h7f8727e_0
+  - libxkbcommon=1.0.1=h5eee18b_1
+  - libxml2=2.10.3=hcbfbd50_0
+  - libxslt=1.1.37=h2085143_0
+  - lz4-c=1.9.4=h6a678d5_0
+  - markupsafe=2.1.1=py38h7f8727e_0
+  - matplotlib=3.7.1=py38h06a4308_1
+  - matplotlib-base=3.7.1=py38h417a72b_1
+  - matplotlib-inline=0.1.6=py38h06a4308_0
+  - mkl=2023.1.0=h6d00ec8_46342
+  - mkl-service=2.4.0=py38h5eee18b_1
+  - mkl_fft=1.3.6=py38h417a72b_1
+  - mkl_random=1.2.2=py38h417a72b_1
+  - mpc=1.1.0=h10f8cd9_1
+  - mpfr=4.0.2=hb69a4c5_1
+  - mpmath=1.2.1=py38h06a4308_0
+  - munkres=1.1.4=py_0
+  - ncurses=6.4=h6a678d5_0
+  - nettle=3.7.3=hbbd107a_1
+  - networkx=2.8.4=py38h06a4308_1
+  - nspr=4.33=h295c915_0
+  - nss=3.74=h0370c37_0
+  - numexpr=2.8.4=py38hc78ab66_1
+  - numpy-base=1.24.3=py38h060ed82_1
+  - openh264=2.1.1=h4ff587b_0
+  - openssl=1.1.1t=h7f8727e_0
+  - pandas=1.5.3=py38h417a72b_0
+  - parso=0.8.3=pyhd3eb1b0_0
+  - pcre=8.45=h295c915_0
+  - pexpect=4.8.0=pyhd3eb1b0_3
+  - pickleshare=0.7.5=pyhd3eb1b0_1003
+  - pillow=9.4.0=py38h6a678d5_0
+  - pip=23.0.1=py38h06a4308_0
+  - ply=3.11=py38_0
+  - prompt-toolkit=3.0.36=py38h06a4308_0
+  - ptyprocess=0.7.0=pyhd3eb1b0_2
+  - pure_eval=0.2.2=pyhd3eb1b0_0
+  - pycparser=2.21=pyhd3eb1b0_0
+  - pygments=2.15.1=py38h06a4308_1
+  - pyopenssl=23.0.0=py38h06a4308_0
+  - pyparsing=3.0.9=py38h06a4308_0
+  - pyqt=5.15.7=py38h6a678d5_1
+  - pyqt5-sip=12.11.0=py38h6a678d5_1
+  - pysocks=1.7.1=py38h06a4308_0
+  - python=3.8.16=h7a1cb2a_3
+  - python-dateutil=2.8.2=pyhd3eb1b0_0
+  - pytorch=2.1.0.dev20230522=py3.8_cuda12.1_cudnn8.8.1_0
+  - pytorch-cuda=12.1=ha16c6d3_5
+  - pytorch-mutex=1.0=cuda
+  - pytz=2022.7=py38h06a4308_0
+  - pyyaml=6.0=py38h5eee18b_1
+  - qt-main=5.15.2=h8373d8f_8
+  - qt-webengine=5.15.9=hbbf29b9_6
+  - qtwebkit=5.212=h3fafdc1_5
+  - readline=8.2=h5eee18b_0
+  - requests=2.29.0=py38h06a4308_0
+  - setuptools=66.0.0=py38h06a4308_0
+  - sip=6.6.2=py38h6a678d5_0
+  - six=1.16.0=pyhd3eb1b0_1
+  - sqlite=3.41.2=h5eee18b_0
+  - stack_data=0.2.0=pyhd3eb1b0_0
+  - sympy=1.11.1=py38h06a4308_0
+  - tbb=2021.8.0=hdb19cb5_0
+  - tk=8.6.12=h1ccaba5_0
+  - toml=0.10.2=pyhd3eb1b0_0
+  - torchaudio=2.1.0.dev20230522=py38_cu121
+  - torchtriton=2.1.0+7d1a95b046=py38
+  - torchvision=0.16.0.dev20230522=py38_cu121
+  - tornado=6.2=py38h5eee18b_0
+  - traitlets=5.7.1=py38h06a4308_0
+  - typing_extensions=4.5.0=py38h06a4308_0
+  - urllib3=1.25.8=py38_0
+  - wcwidth=0.2.5=pyhd3eb1b0_0
+  - wheel=0.38.4=py38h06a4308_0
+  - x264=1!157.20191217=h7b6447c_0
+  - xz=5.4.2=h5eee18b_0
+  - yaml=0.2.5=h7b6447c_0
+  - zlib=1.2.13=h5eee18b_0
+  - zstd=1.5.5=hc292b87_0
+  - pip:
+      - absl-py==1.4.0
+      - astunparse==1.6.3
+      - audioread==3.0.0
+      - cachetools==5.3.0
+      - cython==0.29.34
+      - flatbuffers==2.0.7
+      - fsspec==2023.5.0
+      - gast==0.4.0
+      - google-auth==2.18.1
+      - google-auth-oauthlib==1.0.0
+      - google-pasta==0.2.0
+      - grpcio==1.54.2
+      - h5py==3.8.0
+      - hdbscan==0.8.29
+      - hearbaseline==2021.1.1
+      - huggingface-hub==0.14.1
+      - hyperpyyaml==1.2.0
+      - importlib-metadata==6.6.0
+      - joblib==1.2.0
+      - julius==0.2.7
+      - keras==2.7.0
+      - keras-preprocessing==1.1.2
+      - libclang==16.0.0
+      - librosa==0.9.1
+      - llvmlite==0.31.0
+      - markdown==3.4.3
+      - nnaudio==0.3.2
+      - numba==0.48.0
+      - numpy==1.19.2
+      - oauthlib==3.2.2
+      - opt-einsum==3.3.0
+      - packaging==23.1
+      - platformdirs==3.5.1
+      - pooch==1.7.0
+      - protobuf==3.19.6
+      - pyasn1==0.5.0
+      - pyasn1-modules==0.3.0
+      - pynndescent==0.5.10
+      - regex==2023.5.5
+      - requests-oauthlib==1.3.1
+      - resampy==0.2.2
+      - rsa==4.9
+      - ruamel-yaml==0.17.26
+      - ruamel-yaml-clib==0.2.7
+      - scikit-learn==1.2.2
+      - scipy==1.9.3
+      - sentencepiece==0.1.99
+      - soundfile==0.12.1
+      - speechbrain==0.5.14
+      - tensorboard==2.13.0
+      - tensorboard-data-server==0.7.0
+      - tensorflow==2.7.4
+      - tensorflow-estimator==2.7.0
+      - tensorflow-io-gcs-filesystem==0.32.0
+      - termcolor==2.3.0
+      - threadpoolctl==3.1.0
+      - tokenizers==0.13.3
+      - torchcrepe==0.0.19
+      - torchopenl3==1.0.1
+      - tqdm==4.65.0
+      - transformers==4.29.2
+      - umap-learn==0.5.3
+      - werkzeug==2.3.4
+      - wrapt==1.15.0
+      - zipp==3.15.0
+prefix: /home/paul.best/miniconda3/envs/hear
diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index b25717cd09b095360ea5bd010ea1645ec949ebae..b849716bd2d221cf588502068af5eadaef804204 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -17,7 +17,8 @@ parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signa
 args = parser.parse_args()
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+#frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+frontend = models.frontend_gibbon  # gibbon-specific frontend hard-coded in place of the generic one above
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
@@ -25,7 +26,8 @@ model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
 df = pd.read_csv(args.detections)
 
 print('Computing AE projections...')
-loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8)
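+# channel=1 reads the second audio channel (see the channel argument added to utils.Dataset)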
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur, channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8)
 with torch.no_grad():
     encodings, idxs = [], []
     for x, idx in tqdm(loader):
diff --git a/new_specie/fetch_annot_from_pngs.py b/new_specie/fetch_annot_from_pngs.py
index 0427b3b21cae40e1e0c3f70e887c8dbb6f5e63fb..09fce4652e96a506ad2eda6c6fd3a689bb37fec3 100755
--- a/new_specie/fetch_annot_from_pngs.py
+++ b/new_specie/fetch_annot_from_pngs.py
@@ -12,7 +12,7 @@ parser.add_argument('annot_folder', type=str, help='Name of the folder containin
 parser.add_argument("detections", type=str, help=".csv file with detections that were clustered (labels will be added to it)")
 args = parser.parse_args()
 
-df = pd.read_csv(args.detections))
+df = pd.read_csv(args.detections)
 
 for label in os.listdir(args.annot_folder+'/'):
     for file in os.listdir(f'{args.annot_folder}/{label}/'):
diff --git a/new_specie/filterbank.py b/new_specie/filterbank.py
index 32e9780e198b55c7388d39ec6e386bd6c29c2213..34ecf14f9750b72dec34fa7f9574398216f044ac 100755
--- a/new_specie/filterbank.py
+++ b/new_specie/filterbank.py
@@ -123,6 +123,15 @@ class STFT(nn.Module):
         x = x.reshape((batchsize, channels, -1) + x.shape[2:]) #if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:])
         return x
 
+class MedFilt(nn.Module):
+    """
+    Subtract a low quantile (0.2) of each frequency band, as a median-filter-like background removal
+    """
+    def __init__(self):
+        super(MedFilt, self).__init__()
+    def forward(self, x):
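+        # the 0.2 quantile along time (dim=-1) serves as a per-band noise floor estimate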
+        return x - torch.quantile(x, 0.2, dim=-1, keepdim=True)
 
 
 class TemporalBatchNorm(nn.Module):
@@ -161,3 +170,4 @@ class Log1p(nn.Module):
 
     def extra_repr(self):
         return 'trainable={}'.format(repr(self.trainable))
+
diff --git a/new_specie/models.py b/new_specie/models.py
index 48d7cf84ec44fc40247731fe0c8b15cbbb54e559..f7dcc80573136e04eed5839b15caf38a11c030c7 100755
--- a/new_specie/models.py
+++ b/new_specie/models.py
@@ -21,7 +21,7 @@ frontend = lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
 sparrow_encoder = lambda nfeat, shape : nn.Sequential(
   nn.Conv2d(1, 32, 3, stride=2, bias=False, padding=(1)),
   nn.BatchNorm2d(32),
-  nn.LeakyReLU(0.01),
+  nn.ReLU(True),
   nn.Conv2d(32, 64, 3, stride=2, bias=False, padding=1),
   nn.BatchNorm2d(64),
   nn.ReLU(True),
@@ -31,7 +31,7 @@ sparrow_encoder = lambda nfeat, shape : nn.Sequential(
   nn.Conv2d(128, 256, 3, stride=2, bias=False, padding=1),
   nn.BatchNorm2d(256),
   nn.ReLU(True),
-  nn.Conv2d(256, nfeat, (3, 5), stride=2, padding=(1, 2)),
+  nn.Conv2d(256, nfeat, 3, stride=2, padding=1),
   u.Reshape(nfeat * shape[0] * shape[1])
 )
 
@@ -72,11 +72,8 @@ sparrow_decoder = lambda nfeat, shape : nn.Sequential(
   nn.ReLU(True),
 
   nn.Upsample(scale_factor=2),
-  nn.Conv2d(32, 32, (3, 3), bias=False, padding=1),
-  nn.BatchNorm2d(32),
-  nn.ReLU(True),
   nn.Conv2d(32, 1, (3, 3), bias=False, padding=1),
-  nn.ReLU(True)
+  nn.BatchNorm2d(1),
+  nn.ReLU(True),
+  nn.Conv2d(1, 1, (3, 3), bias=False, padding=1),
 )
-
-
diff --git a/new_specie/requirements.txt b/new_specie/requirements.txt
index 5a9896469de69c28f190fb8131c99b03c41de438..9b25a7728b5ec988c9d7c266b42198f9f1240d0c 100755
--- a/new_specie/requirements.txt
+++ b/new_specie/requirements.txt
@@ -1,12 +1,44 @@
+albumentations==1.3.0
+comet_ml==3.33.3
+coremltools==6.3.0
+Flask==2.3.2
+get_file==0.1.6
+GitPython==3.1.31
 hdbscan==0.8.28
-matplotlib==3.5.1
-numpy==1.22.3
-pandas==1.4.1
-scipy==1.8.0
-sounddevice==0.4.5
+ipython==8.5.0
+matplotlib==3.6.0
+mss==9.0.1
+numpy==1.23.5
+onnx==1.14.0
+onnxruntime==1.15.0
+onnxsim==0.4.28
+opencv_python==4.7.0.72
+openvino==2022.3.0
+paddle==1.0.2
+pafy==0.5.5
+pandas==1.5.0
+Pillow==9.5.0
+plotly==5.13.0
+psutil==5.9.4
+pycocotools==2.0.6
+PyDrive==1.3.1
+PySoundFile==0.9.0.post1
+PyYAML==6.0
+Requests==2.31.0
+scipy==1.9.1
+seaborn==0.12.2
+setuptools==67.6.1
+sounddevice==0.4.6
 soundfile==0.11.0
-torch==1.11.0+cu113
-torchvision==0.12.0+cu113
-tqdm==4.64.0
-umap==0.1.1
+tensorflow==2.12.0
+tensorflowjs==4.6.0
+tensorrt==8.6.1
+tflite_runtime==2.12.0
+tflite_support==0.4.3
+thop==0.1.1.post2209072238
+torch==1.12.1+cu113
+torchvision==0.13.1+cu113
+tqdm==4.64.1
+tritonclient==2.33.0
 umap_learn==0.5.3
+x2paddle==1.4.1
diff --git a/new_specie/sort_cluster.py b/new_specie/sort_cluster.py
index 6ab1a705ffff57582999dfc6986d2c7f3105cd6d..2f36c31d77efa6cae9aab14d48b60b643693faf7 100755
--- a/new_specie/sort_cluster.py
+++ b/new_specie/sort_cluster.py
@@ -5,6 +5,7 @@ from tqdm import tqdm
 import matplotlib.pyplot as plt
 import os
 import torch, numpy as np, pandas as pd
+from filterbank import STFT, MelFilter, MedFilt, Log1p
 import hdbscan
 import argparse
 import models
@@ -22,17 +23,18 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFo
 parser.add_argument('encodings', type=str, help='.npy file containing umap projections and their associated index in the detection.pkl table (built using compute_embeddings.py)')
 parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
 #parser.add_argument('audio_folder', type=str, help='Path to the folder with complete audio files')
-parser.add_argument("-audio_folder", type=str, default='', help="Folder from which to load sound files")
+parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
 parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
 parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
 parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
 parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.')
+parser.add_argument('-channel', type=int, default=0, help='Audio channel to use for multi-channel recordings')
 parser.add_argument('-min_sample', type=int, default=5, help='Used for HDBSCAN clustering.')
 parser.add_argument('-eps', type=float, default=0.05, help='Used for HDBSCAN clustering.')
 args = parser.parse_args()
 
-df = pd.read_csv(args.detections)
+df = pd.read_csv(args.detections, index_col=0)
 encodings = np.load(args.encodings, allow_pickle=True).item()
 idxs, umap = encodings['idx'], encodings['umap']
 df.loc[idxs, 'umap_x'] = umap[:,0]
@@ -46,7 +48,14 @@ df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
                                 cluster_selection_method='leaf').fit_predict(umap)
 df.cluster = df.cluster.astype(int)
 
-frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
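+# spectrogram frontend hard-coded here; the -SR, -NFFT and -nMel arguments above are bypassed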
+fs = 44100
+frontend = torch.nn.Sequential(
+  STFT(2048, 256),
+  MelFilter(fs, 2048, 96, 500, 4000),
+  Log1p(4),
+  MedFilt()
+)
 
 figscat = plt.figure(figsize=(10, 5))
 plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
@@ -73,7 +82,7 @@ class temp():
         dur, fs = info.duration, info.samplerate
         start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
         sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
-        sig = sig[:,0]
+        sig = sig[:, args.channel]
         if fs != args.SR:
             sig = signal.resample(sig, int(len(sig)/fs*args.SR))
         spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
@@ -88,7 +97,7 @@ class temp():
         figscat.canvas.draw()
         # Play the audio
         if soundAvailable:
-            sd.play(sig*10, fs)
+            sd.play(sig, fs)
 
 mtemp = temp()
 cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
@@ -108,7 +117,7 @@ for c, grp in df.groupby('cluster'):
     loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8)
     with torch.no_grad():
         for x, idx in tqdm(loader, leave=False, desc=str(c)):
-            plt.specgram(x.squeeze().numpy())
-            plt.tight_layout()
+            plt.imshow(frontend(x).squeeze().numpy(), origin='lower', aspect='auto')
+            plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
             plt.savefig(f'cluster_pngs/{c}/{idx.squeeze().item()}')
             plt.close()
diff --git a/new_specie/utils.py b/new_specie/utils.py
index 5c6644f6546275418b03cf9741321b4de758a0d6..09acb635fbd5b8f5b07d63910e52731cf8d5cb38 100755
--- a/new_specie/utils.py
+++ b/new_specie/utils.py
@@ -10,9 +10,9 @@ def collate_fn(batch):
     return dataloader.default_collate(batch)
 
 class Dataset(Dataset):
-    def __init__(self, df, audiopath, sr, sampleDur):
+    def __init__(self, df, audiopath, sr, sampleDur, channel=0):
         super(Dataset, self)
-        self.audiopath, self.df, self.sr, self.sampleDur = audiopath, df, sr, sampleDur
+        self.audiopath, self.df, self.sr, self.sampleDur, self.channel = audiopath, df, sr, sampleDur, channel
 
     def __len__(self):
         return len(self.df)
@@ -24,9 +24,9 @@ class Dataset(Dataset):
             dur, fs = info.duration, info.samplerate
             start = int(np.clip(row.pos - self.sampleDur/2, 0, max(0, dur - self.sampleDur)) * fs)
             sig, fs = sf.read(self.audiopath+'/'+row.filename, start=start, stop=start + int(self.sampleDur*fs), always_2d=True)
-            sig = sig[:,0]
-        except:
-            print(f'Failed to load sound from row {row.name} with filename {row.filename}')
+            sig = sig[:, row.Channel - 1 if 'Channel' in row else self.channel]  # 'Channel' column is 1-indexed when present
+        except Exception as e:
+            print(f'Failed to load sound from row {row.name} with filename {row.filename}', e)
             return None
         if len(sig) < self.sampleDur * fs:
             sig = np.concatenate([sig, np.zeros(int(self.sampleDur * fs) - len(sig))])
diff --git a/PCEN_pytorch.py b/paper_experiments/PCEN_pytorch.py
similarity index 100%
rename from PCEN_pytorch.py
rename to paper_experiments/PCEN_pytorch.py
diff --git a/bengalese_finch1/bengalese_finch1.csv b/paper_experiments/bengalese_finch1/bengalese_finch1.csv
similarity index 100%
rename from bengalese_finch1/bengalese_finch1.csv
rename to paper_experiments/bengalese_finch1/bengalese_finch1.csv
diff --git a/bengalese_finch1/citation.txt b/paper_experiments/bengalese_finch1/citation.txt
similarity index 100%
rename from bengalese_finch1/citation.txt
rename to paper_experiments/bengalese_finch1/citation.txt
diff --git a/bengalese_finch2/bengalese_finch2.csv b/paper_experiments/bengalese_finch2/bengalese_finch2.csv
similarity index 100%
rename from bengalese_finch2/bengalese_finch2.csv
rename to paper_experiments/bengalese_finch2/bengalese_finch2.csv
diff --git a/bengalese_finch2/citation.txt b/paper_experiments/bengalese_finch2/citation.txt
similarity index 100%
rename from bengalese_finch2/citation.txt
rename to paper_experiments/bengalese_finch2/citation.txt
diff --git a/black-headed_grosbeaks/black-headed_grosbeaks.csv b/paper_experiments/black-headed_grosbeaks/black-headed_grosbeaks.csv
similarity index 100%
rename from black-headed_grosbeaks/black-headed_grosbeaks.csv
rename to paper_experiments/black-headed_grosbeaks/black-headed_grosbeaks.csv
diff --git a/black-headed_grosbeaks/filelist.txt b/paper_experiments/black-headed_grosbeaks/filelist.txt
similarity index 100%
rename from black-headed_grosbeaks/filelist.txt
rename to paper_experiments/black-headed_grosbeaks/filelist.txt
diff --git a/black-headed_grosbeaks/source.txt b/paper_experiments/black-headed_grosbeaks/source.txt
similarity index 100%
rename from black-headed_grosbeaks/source.txt
rename to paper_experiments/black-headed_grosbeaks/source.txt
diff --git a/california_thrashers/california_thrashers.csv b/paper_experiments/california_thrashers/california_thrashers.csv
similarity index 100%
rename from california_thrashers/california_thrashers.csv
rename to paper_experiments/california_thrashers/california_thrashers.csv
diff --git a/california_thrashers/filelist.txt b/paper_experiments/california_thrashers/filelist.txt
similarity index 100%
rename from california_thrashers/filelist.txt
rename to paper_experiments/california_thrashers/filelist.txt
diff --git a/california_thrashers/source.txt b/paper_experiments/california_thrashers/source.txt
similarity index 100%
rename from california_thrashers/source.txt
rename to paper_experiments/california_thrashers/source.txt
diff --git a/cassin_vireo/cassin_vireo.csv b/paper_experiments/cassin_vireo/cassin_vireo.csv
similarity index 100%
rename from cassin_vireo/cassin_vireo.csv
rename to paper_experiments/cassin_vireo/cassin_vireo.csv
diff --git a/cassin_vireo/filelist.txt b/paper_experiments/cassin_vireo/filelist.txt
similarity index 100%
rename from cassin_vireo/filelist.txt
rename to paper_experiments/cassin_vireo/filelist.txt
diff --git a/cassin_vireo/source.txt b/paper_experiments/cassin_vireo/source.txt
similarity index 100%
rename from cassin_vireo/source.txt
rename to paper_experiments/cassin_vireo/source.txt
diff --git a/compute_embeddings.py b/paper_experiments/compute_embeddings.py
similarity index 100%
rename from compute_embeddings.py
rename to paper_experiments/compute_embeddings.py
diff --git a/dolphin/dolphin.csv b/paper_experiments/dolphin/dolphin.csv
similarity index 100%
rename from dolphin/dolphin.csv
rename to paper_experiments/dolphin/dolphin.csv
diff --git a/filterbank.py b/paper_experiments/filterbank.py
similarity index 100%
rename from filterbank.py
rename to paper_experiments/filterbank.py
diff --git a/good_species.txt b/paper_experiments/good_species.txt
similarity index 100%
rename from good_species.txt
rename to paper_experiments/good_species.txt
diff --git a/hdbscan_gridsearch.py b/paper_experiments/hdbscan_gridsearch.py
similarity index 100%
rename from hdbscan_gridsearch.py
rename to paper_experiments/hdbscan_gridsearch.py
diff --git a/humpback/humpback.csv b/paper_experiments/humpback/humpback.csv
similarity index 100%
rename from humpback/humpback.csv
rename to paper_experiments/humpback/humpback.csv
diff --git a/humpback2/extract_annot.py b/paper_experiments/humpback2/extract_annot.py
similarity index 100%
rename from humpback2/extract_annot.py
rename to paper_experiments/humpback2/extract_annot.py
diff --git a/humpback2/humpback2.csv b/paper_experiments/humpback2/humpback2.csv
similarity index 100%
rename from humpback2/humpback2.csv
rename to paper_experiments/humpback2/humpback2.csv
diff --git a/models.py b/paper_experiments/models.py
similarity index 100%
rename from models.py
rename to paper_experiments/models.py
diff --git a/plot_annot_distrib.py b/paper_experiments/plot_annot_distrib.py
similarity index 100%
rename from plot_annot_distrib.py
rename to paper_experiments/plot_annot_distrib.py
diff --git a/plot_clusters.py b/paper_experiments/plot_clusters.py
similarity index 65%
rename from plot_clusters.py
rename to paper_experiments/plot_clusters.py
index e71515f124b9a60d94eefea46c14738fa542e5e2..511bdfb7c43973160da9172bff7414cd5206c53c 100755
--- a/plot_clusters.py
+++ b/paper_experiments/plot_clusters.py
@@ -9,19 +9,23 @@ fig, ax = plt.subplots(nrows=len(species), figsize=(7, 10), dpi=200)
 for i, specie in enumerate(species):
     meta = models.meta[specie]
     frontend = models.frontend['pcenMel'](meta['sr'], meta['nfft'], meta['sampleDur'], 128)
-    dic = np.load(f'{specie}/encodings//encodings_{specie}_256_pcenMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
-    idxs, X = dic['idxs'], dic['umap']
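+    # use the log-Mel auto-encoder encodings and their 8-D UMAP projection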
+    dic = np.load(f'{specie}/encodings//encodings_{specie}_256_logMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
+    idxs, X = dic['idxs'], dic['umap8']
     df = pd.read_csv(f'{specie}/{specie}.csv')
-    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
+    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if 'humpback' not in specie else 'eom').fit_predict(X)
     df.loc[idxs, 'cluster'] = clusters.astype(int)
     for j, cluster in enumerate(np.random.choice(np.arange(df.cluster.max()), 4)):
-        for k, (x, name) in enumerate(torch.utils.data.DataLoader(u.Dataset(df[df.cluster==cluster].sample(8), f'{specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=1)):
+        loader = torch.utils.data.DataLoader(u.Dataset(df[df.cluster==cluster].sample(8), f'{specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=1)
+        for k, (x, name) in enumerate(loader):
             spec = frontend(x).squeeze().numpy()
             ax[i].imshow(spec, extent=[k, k+1, j, j+1], origin='lower', aspect='auto', cmap='Greys', vmin=np.quantile(spec, .2), vmax=np.quantile(spec, .98))
     ax[i].set_xticks([])
     ax[i].set_yticks([])
+    if specie == 'humpback2':
+        specie = 'humpback\n(small)'
     # ax[i].grid(color='w', xdata=np.arange(1, 10), ydata=np.arange(1, 5))
-    ax[i].set_ylabel(specie.replace('_', ' '))
+    ax[i].set_ylabel(specie.replace('_', '\n'))
     ax[i].set_xlim(0, 8)
     ax[i].set_ylim(0, 4)
     ax[i].vlines(np.arange(1, 8), 0, 4, linewidths=1, color='black')
diff --git a/plot_main_results.py b/paper_experiments/plot_main_results.py
similarity index 50%
rename from plot_main_results.py
rename to paper_experiments/plot_main_results.py
index 75de2a853666619347d35ae92f30b20c939fd857..b766250aed36dcf63d6414b900908843b6708ea8 100755
--- a/plot_main_results.py
+++ b/paper_experiments/plot_main_results.py
@@ -9,8 +9,10 @@ all_frontendNames = [['AE prcptl', 'AE MSE', 'Spectro.', 'PAFs'][::-1],
                      ['AE', 'gen. AE', 'OpenL3', 'Wav2Vec2', 'CREPE'][::-1],
                      ['log-Mel', 'Mel', 'PCEN-Mel', 'log-STFT']]
 
+# Bar plots of NMI depending on the choice of feature extraction
 all_plotNames = ['handcraft', 'deepembed', 'frontends']
 for frontends, frontendNames, plotName in zip(all_frontends, all_frontendNames, all_plotNames):
+    print(plotName)
     df = pd.read_csv('hdbscan_HP.csv')
     df.loc[df.ms.isna(), 'ms'] = 0
     best = df.loc[df.groupby(["specie", 'frontend']).nmi.idxmax()]
@@ -29,3 +31,46 @@ for frontends, frontendNames, plotName in zip(all_frontends, all_frontendNames,
     plt.grid(axis='y')
     plt.tight_layout()
     plt.savefig(f'NMIs_hdbscan_barplot_{plotName}.pdf')
+
+
+# Plot of NMI as a function of the number of UMAP components (2, 4, 8, 16, 32), normalised per species by its best NMI
+fig, ax = plt.subplots(ncols=2, sharey=True, sharex=True, figsize=(10, 3.5))
+# Left panel with auto-encoder feature extraction
+df = pd.read_csv("hdbscan_HP_archive2.csv")
+df.loc[df.ms.isna(), "ms"] = 0
+df = df[((df.frontend == "256_logMel128")&(df.al=='leaf')&(df.mcs==10)&(df.ms==3)&(df.eps==.1))].sort_values('specie')
+#df = df[df.frontend == "256_logMel128"].sort_values('specie')
+for s, grp in df.groupby("specie"):
+    ax[0].plot(
+        np.arange(5),
+        [grp[grp.ncomp == n].nmi.max() / grp.nmi.max() for n in 2 ** np.arange(1, 6)],
+    )
+# Right panel with Spectrogram feature extraction
+df = pd.read_csv("hdbscan_HP_archive2.csv")
+df.loc[df.ms.isna(), "ms"] = 0
+df = df[((df.frontend == "spec32")&(df.al=='leaf')&(df.mcs==10)&(df.ms==3)&(df.eps==.1))].sort_values('specie')
+#df = df[df.frontend == "spec32"].sort_values('specie')
+for s, grp in df.groupby("specie"):
+    if s.endswith('s'):
+        s = s[:-1]
+    if s == 'humpback2':
+        s = 'humpback (small)'
+    ax[1].plot(
+        np.arange(5),
+        [grp[grp.ncomp == n].nmi.max() / grp.nmi.max() for n in 2 ** np.arange(1, 6)],
+        label=s.replace('_',' '),
+    )
+plt.legend()
+ax[0].set_ylabel("NMI / max(NMI)")
+ax[1].set_ylim(0.85, 1.01)
+ax[0].grid()
+ax[1].grid()
+ax[0].set_title("auto-encoder")
+ax[1].set_title("spectrogram")
+ax[0].set_xlabel("# UMAP dimensions")
+ax[1].set_xlabel("# UMAP dimensions")
+ax[0].set_xticks(np.arange(5))
+ax[0].set_xticklabels(2 ** np.arange(1, 6))
+ax[0].set_yticks(np.arange(0.85, 1.01, 0.05))
+plt.tight_layout()
+plt.savefig('NMI_umap.pdf')
diff --git a/plot_prec_rec.py b/paper_experiments/plot_prec_rec.py
similarity index 95%
rename from plot_prec_rec.py
rename to paper_experiments/plot_prec_rec.py
index 22f38f797ae156331408c01fbb871f29c3e9d6f9..7d4615488947af78e66243f22a29053f9f9dc825 100755
--- a/plot_prec_rec.py
+++ b/paper_experiments/plot_prec_rec.py
@@ -22,12 +22,12 @@ for specie in species:
     dic = np.load(f'{specie}/encodings//encodings_{specie}_256_logMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
     idxs, X = dic['idxs'], dic['umap8']
     df = pd.read_csv(f'{specie}/{specie}.csv')
-    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpbacjjk' in specie else 'eom').fit_predict(X)
+    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if 'humpback' not in specie else 'eom').fit_predict(X)
     df.loc[idxs, 'cluster'] = clusters.astype(int)
 
     dic = np.load(f'{specie}/encodings/encodings_spec32.npy', allow_pickle=True).item()
     idxs, X = dic['idxs'], dic['umap8']
-    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpbacjjk' in specie else 'eom').fit_predict(X)
+    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if 'humpback' not in specie else 'eom').fit_predict(X)
     df.loc[idxs, 'cluster2'] = clusters.astype(int)
 
     mask = ~df.loc[idxs].label.isna()
diff --git a/plot_projections.py b/paper_experiments/plot_projections.py
similarity index 100%
rename from plot_projections.py
rename to paper_experiments/plot_projections.py
diff --git a/run_baseline.py b/paper_experiments/run_PAF_baseline.py
similarity index 100%
rename from run_baseline.py
rename to paper_experiments/run_PAF_baseline.py
diff --git a/run_hearbaseline.py b/paper_experiments/run_hearbaseline.py
similarity index 100%
rename from run_hearbaseline.py
rename to paper_experiments/run_hearbaseline.py
diff --git a/run_spec32_baseline.py b/paper_experiments/run_spec32_baseline.py
similarity index 100%
rename from run_spec32_baseline.py
rename to paper_experiments/run_spec32_baseline.py
diff --git a/test_AE.py b/paper_experiments/test_AE.py
similarity index 100%
rename from test_AE.py
rename to paper_experiments/test_AE.py
diff --git a/test_AE_all.py b/paper_experiments/test_AE_all.py
similarity index 100%
rename from test_AE_all.py
rename to paper_experiments/test_AE_all.py
diff --git a/train_AE.py b/paper_experiments/train_AE.py
similarity index 100%
rename from train_AE.py
rename to paper_experiments/train_AE.py
diff --git a/train_AE_all.py b/paper_experiments/train_AE_all.py
similarity index 100%
rename from train_AE_all.py
rename to paper_experiments/train_AE_all.py
diff --git a/utils.py b/paper_experiments/utils.py
similarity index 100%
rename from utils.py
rename to paper_experiments/utils.py
diff --git a/plot_results_kmeans.py b/plot_results_kmeans.py
deleted file mode 100755
index a60bdb761eb54e3de15ab9d1465d9219716a8ac9..0000000000000000000000000000000000000000
--- a/plot_results_kmeans.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import hdbscan
-import pandas as pd
-import matplotlib.pyplot as plt
-import numpy as np
-from sklearn import metrics, cluster
-from scipy.stats import linregress
-
-species = np.loadtxt('good_species.txt', dtype=str)
-frontends = ['16_pcenMel128', '16_logMel128', '16_logSTFT', '16_Mel128', '8_pcenMel64', '32_pcenMel128']
-plt.figure()
-for specie in species:
-    df = pd.read_csv(f'{specie}/{specie}.csv')
-    nmis = []
-    for i, frontend in enumerate(frontends):
-        print(specie, frontend)
-        dic = np.load(f'{specie}/encodings_{specie}_{frontend}_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
-        idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap']
-
-        ks = (5*1.2**np.arange(20)).astype(int)
-        distorsions = [cluster.KMeans(n_clusters=k).fit(encodings).inertia_ for k in ks]
-        errors = [linregress(ks[:i], distorsions[:i]).stderr + linregress(ks[i+1:], distorsions[i+1:]).stderr for i in range(2, len(ks)-2)]
-        k = ks[np.argmin(errors)]
-        clusters = cluster.KMeans(n_clusters=k).fit_predict(encodings)
-        df.loc[idxs, 'cluster'] = clusters.astype(int)
-
-        mask = ~df.loc[idxs].label.isna()
-        clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-        nmis.append(metrics.normalized_mutual_info_score(labels, clusters))
-        df.drop('cluster', axis=1, inplace=True)
-    plt.scatter(nmis, np.arange(len(frontends)), label=specie)
-
-plt.yticks(range(len(frontends)), frontends)
-plt.ylabel('archi')
-plt.xlabel('NMI with expert labels')
-plt.grid()
-plt.legend()
-plt.tight_layout()
-plt.savefig('NMIs_kmeans.pdf')
diff --git a/print_annot.py b/print_annot.py
deleted file mode 100755
index da77b06dc33611633b0de7e461b9dac729effc71..0000000000000000000000000000000000000000
--- a/print_annot.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import os, shutil, argparse
-from tqdm import tqdm
-import matplotlib.pyplot as plt
-import pandas as pd, numpy as np
-import models, utils as u
-import torch
-
-parser = argparse.ArgumentParser()
-parser.add_argument("specie", type=str)
-parser.add_argument("-frontend", type=str, default='logMel')
-parser.add_argument("-nMel", type=int, default=128)
-args = parser.parse_args()
-
-meta = models.meta[args.specie]
-df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
-frontend = models.frontend[args.frontend](meta['sr'], meta['nfft'], meta['sampleDur'], args.nMel)
-shutil.rmtree(f'{args.specie}/annot_pngs', ignore_errors=True)
-
-for label, grp in df.groupby('label'):
-    os.makedirs(f'{args.specie}/annot_pngs/{label}', exist_ok=True)
-    loader = torch.utils.data.DataLoader(u.Dataset(grp, args.specie+'/audio/', meta['sr'], meta['sampleDur']),\
-                                         batch_size=1, num_workers=4, pin_memory=True)
-    for x, idx in tqdm(loader, desc=args.specie + ' ' + label, leave=False):
-        x = frontend(x).squeeze().detach()
-        assert not torch.isnan(x).any(), "Found a NaN in spectrogram... :/"
-        plt.figure()
-        plt.imshow(x, origin='lower', aspect='auto', cmap='Greys', vmin=np.quantile(x, .7))
-        plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
-        row = df.loc[idx.item()]
-        #plt.savefig(f'{args.specie}/annot_pngs/{label}/{row.fn.split(".")[0]}_{row.pos:.2f}.png')
-        plt.savefig(f'{args.specie}/annot_pngs/{label}/{idx.item()}')
-        plt.close()
-
-
-
diff --git a/print_reconstr.py b/print_reconstr.py
deleted file mode 100755
index 66efcbf29cf6e88038362fd63b9ed3efd7fb1068..0000000000000000000000000000000000000000
--- a/print_reconstr.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import os, shutil, argparse
-from tqdm import tqdm
-import matplotlib.pyplot as plt
-import pandas as pd, numpy as np
-import models, utils as u
-import torch
-
-parser = argparse.ArgumentParser()
-parser.add_argument("specie", type=str)
-parser.add_argument("-bottleneck", type=int, default=16)
-parser.add_argument("-frontend", type=str, default='logMel')
-parser.add_argument("-encoder", type=str, default='sparrow_encoder')
-parser.add_argument("-prcptl", type=int, default=1)
-parser.add_argument("-nMel", type=int, default=128)
-args = parser.parse_args()
-
-modelname = f'{args.specie}_{args.bottleneck}_{args.frontend}{args.nMel if "Mel" in args.frontend else ""}_{args.encoder}_decod2_BN_nomaxPool{"_noprcptl" if args.prcptl==0 else ""}.stdc'
-print(modelname)
-gpu = torch.device(f'cuda')
-
-meta = models.meta[args.specie]
-frontend = models.frontend[args.frontend](meta['sr'], meta['nfft'], meta['sampleDur'], args.nMel)
-encoder = models.__dict__[args.encoder](*((args.bottleneck // 16, (4, 4)) if args.nMel == 128 else (args.bottleneck // 8, (2, 4))))
-decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4))
-model = torch.nn.Sequential(frontend, encoder, decoder).to(gpu)
-model.load_state_dict(torch.load(f'{args.specie}/weights/{modelname}'))
-model.eval()
-
-df = pd.read_csv(f'{args.specie}/{args.specie}.csv')
-
-shutil.rmtree(f'{args.specie}/reconstruct_pngs', ignore_errors=True)
-
-for label, grp in df.groupby('label'):
-    os.makedirs(f'{args.specie}/reconstruct_pngs/{label}', exist_ok=True)
-    loader = torch.utils.data.DataLoader(u.Dataset(grp, args.specie+'/audio/', meta['sr'], meta['sampleDur']),\
-                                         batch_size=1, num_workers=4, pin_memory=True)
-    with torch.no_grad():
-        for x, idx in tqdm(loader, desc=args.specie + ' ' + label, leave=False):
-            x = model(x.to(gpu)).squeeze().detach().cpu()
-            assert not torch.isnan(x).any(), "Found a NaN in spectrogram... :/"
-            plt.imshow(x, origin='lower', aspect='auto') #, cmap='Greys', vmin=np.quantile(x, .7))
-            plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
-            row = df.loc[idx.item()]
-            plt.savefig(f'{args.specie}/reconstruct_pngs/{label}/{idx.item()}')
-            plt.close()
-
-
-
diff --git a/sort_cluster.py b/sort_cluster.py
deleted file mode 100755
index ac8d92e854dff183832deb875019cdb0b5eee7b7..0000000000000000000000000000000000000000
--- a/sort_cluster.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import utils as u
-from tqdm import tqdm
-import torch
-import matplotlib.pyplot as plt
-import soundfile as sf
-import models
-import os
-import numpy as np
-import pandas as pd
-import hdbscan
-import argparse
-try:
-    import sounddevice as sd
-    soundAvailable = True
-except:
-    soundAvailable = False
-
-parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description='blabla')
-parser.add_argument('encodings', type=str)
-parser.add_argument('--min_cluster_size', type=int, default=50)
-parser.add_argument('--min_sample', type=int, default=10)
-parser.add_argument('--eps', type=float, default=1)
-args = parser.parse_args()
-
-gpu = torch.device('cuda:0')
-frontend = models.get['frontend_logMel'].to(gpu)
-
-a = np.load(args.encodings, allow_pickle=True).item()
-df = pd.read_pickle('detections.pkl')
-idxs, umap = a['idx'], a['umap']
-
-# cluster the embedings (min_cluster_size and min_samples parameters need to be tuned)
-df.loc[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
-                                          min_samples=args.min_sample,
-                                          core_dist_n_jobs=-1,
-#                                          cluster_selection_epsilon=args.eps,
-                                          cluster_selection_method='leaf').fit_predict(umap)
-df.loc[idxs, ['umap_x', 'umap_y']] = umap
-df.cluster = df.cluster.astype(int)
-
-figscat = plt.figure(figsize=(20, 10))
-plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
-for c, grp in df.groupby('cluster'):
-    plt.scatter(grp.umap_x, grp.umap_y, s=3, alpha=.1, c='grey' if c == -1 else None)
-#plt.scatter(df[((df.type!='Vocalization')&(~df.type.isna()))].umap_x, df[((df.type!='Vocalization')&(~df.type.isna()))].umap_y, marker='x')
-plt.tight_layout()
-axScat = figscat.axes[0]
-plt.savefig('projection')
-figSpec = plt.figure()
-plt.scatter(0, 0)
-axSpec = figSpec.axes[0]
-
-print(df.groupby('cluster').count())
-
-class temp():
-    def __init__(self):
-        self.row = ""
-    def onclick(self, event):
-        #get row
-        left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1]
-        rangex, rangey =  right - left, top - bottom
-        closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y  - event.ydata)/rangey)**2)).idxmin()
-        sig, fs = sf.read(f'/data_ssd/marmossets/{closest}.wav')
-        spec = frontend(torch.Tensor(sig).to(gpu).view(1, -1).float()).detach().cpu().squeeze()
-        axSpec.imshow(spec, origin='lower', aspect='auto')
-        row = df.loc[closest]
-        axSpec.set_title(f'{closest}, cluster {row.cluster} ({(df.cluster==row.cluster).sum()} points)')
-        axScat.scatter(row.umap_x, row.umap_y, c='r')
-        axScat.set_xlim(left, right)
-        axScat.set_ylim(bottom, top)
-        figSpec.canvas.draw()
-        figscat.canvas.draw()
-        if soundAvailable:
-            sd.play(sig*10, fs)
-
-mtemp = temp()
-
-cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
-
-plt.show()
-
-if input('print cluster pngs ??') != 'y':
-    exit()
-
-os.system('rm -R cluster_pngs/*')
-for c, grp in df.groupby('cluster'):
-    if c == -1 or len(grp) > 10_000:
-        continue
-    os.system('mkdir cluster_pngs/'+str(c))
-    with torch.no_grad():
-        for x, idx in tqdm(torch.utils.data.DataLoader(u.Dataset(grp.sample(200), sampleDur=.5), batch_size=1, num_workers=8), leave=False, desc=str(c)):
-            x = x.to(gpu)
-            x = frontend(x).cpu().detach().squeeze()
-            plt.imshow(x, origin='lower', aspect='auto')
-            plt.savefig(f'cluster_pngs/{c}/{idx.squeeze().item()}')
-            plt.close()