diff --git a/.gitignore b/.gitignore
index 76ccee7b16358c65b6c72e2d84db0c85d666f015..3049b72a7212aba0d198157d05cd57ae56780eae 100755
--- a/.gitignore
+++ b/.gitignore
@@ -1,17 +1,18 @@
 *.png
 *.stdc
 *.npy
-*/audio
+paper_experiments/*/audio
 */TextGrid
 *__pycache__
 *log
-humpback2/annots
+paper_experiments/humpback2/annots
 gibbon
 new_specie/*/
 otter/pone.0112562.s003.xlsx
 zebra_finch/Library_notes.pdf
 annot_distrib.pdf
 annot_distrib.tex
-humpback/annot
+paper_experiments/humpback/annot
 humpback_CARIMAM/
-dolphin/zips
+paper_experiments/dolphin/zips
+archive/*
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e3f14220ccbd6964cf06f183f87fd708d9ea595a
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,240 @@
+name: hear
+channels:
+  - pytorch-nightly
+  - nvidia
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - asttokens=2.0.5=pyhd3eb1b0_0
+  - backcall=0.2.0=pyhd3eb1b0_0
+  - blas=1.0=mkl
+  - bottleneck=1.3.5=py38h7deecbd_0
+  - brotli=1.0.9=h5eee18b_7
+  - brotli-bin=1.0.9=h5eee18b_7
+  - bzip2=1.0.8=h7b6447c_0
+  - ca-certificates=2023.01.10=h06a4308_0
+  - certifi=2023.5.7=py38h06a4308_0
+  - cffi=1.15.1=py38h5eee18b_3
+  - charset-normalizer=2.0.4=pyhd3eb1b0_0
+  - contourpy=1.0.5=py38hdb19cb5_0
+  - cryptography=39.0.1=py38h9ce1e76_0
+  - cuda-cudart=12.1.105=0
+  - cuda-cupti=12.1.105=0
+  - cuda-libraries=12.1.0=0
+  - cuda-nvrtc=12.1.105=0
+  - cuda-nvtx=12.1.105=0
+  - cuda-opencl=12.1.105=0
+  - cuda-runtime=12.1.0=0
+  - cycler=0.11.0=pyhd3eb1b0_0
+  - dbus=1.13.18=hb2f20db_0
+  - decorator=5.1.1=pyhd3eb1b0_0
+  - executing=0.8.3=pyhd3eb1b0_0
+  - expat=2.4.9=h6a678d5_0
+  - ffmpeg=4.2.2=h20bf706_0
+  - filelock=3.9.0=py38h06a4308_0
+  - fontconfig=2.14.1=h4c34cd2_2
+  - fonttools=4.25.0=pyhd3eb1b0_0
+  - freetype=2.12.1=h4a9f257_0
+  - giflib=5.2.1=h5eee18b_3
+  - glib=2.69.1=he621ea3_2
+  - gmp=6.2.1=h295c915_3
+  - gmpy2=2.1.2=py38heeb90bb_0
+  - gnutls=3.6.15=he1e5248_0
+  - gst-plugins-base=1.14.1=h6a678d5_1
+  - gstreamer=1.14.1=h5eee18b_1
+  - icu=58.2=he6710b0_3
+  - idna=3.4=py38h06a4308_0
+  - importlib_metadata=6.0.0=hd3eb1b0_0
+  - importlib_resources=5.2.0=pyhd3eb1b0_1
+  - intel-openmp=2023.1.0=hdb19cb5_46305
+  - ipython=8.12.0=py38h06a4308_0
+  - jedi=0.18.1=py38h06a4308_1
+  - jinja2=3.1.2=py38h06a4308_0
+  - jpeg=9e=h5eee18b_1
+  - kiwisolver=1.4.4=py38h6a678d5_0
+  - krb5=1.19.4=h568e23c_0
+  - lame=3.100=h7b6447c_0
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - lerc=3.0=h295c915_0
+  - libbrotlicommon=1.0.9=h5eee18b_7
+  - libbrotlidec=1.0.9=h5eee18b_7
+  - libbrotlienc=1.0.9=h5eee18b_7
+  - libclang13=14.0.6=default_he11475f_1
+  - libcublas=12.1.0.26=0
+  - libcufft=11.0.2.4=0
+  - libcufile=1.6.1.9=0
+  - libcurand=10.3.2.106=0
+  - libcusolver=11.4.4.55=0
+  - libcusparse=12.0.2.55=0
+  - libdeflate=1.17=h5eee18b_0
+  - libedit=3.1.20221030=h5eee18b_0
+  - libevent=2.1.12=h8f2d780_0
+  - libffi=3.4.4=h6a678d5_0
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgfortran-ng=7.5.0=ha8ba4b0_17
+  - libgfortran4=7.5.0=ha8ba4b0_17
+  - libgomp=11.2.0=h1234567_1
+  - libidn2=2.3.4=h5eee18b_0
+  - libllvm14=14.0.6=hdb19cb5_3
+  - libnpp=12.0.2.50=0
+  - libnvjitlink=12.1.105=0
+  - libnvjpeg=12.1.0.39=0
+  - libopus=1.3.1=h7b6447c_0
+  - libpng=1.6.39=h5eee18b_0
+  - libpq=12.9=h16c4e8d_3
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libtasn1=4.19.0=h5eee18b_0
+  - libtiff=4.5.0=h6a678d5_2
+  - libunistring=0.9.10=h27cfd23_0
+  - libuuid=1.41.5=h5eee18b_0
+  - libvpx=1.7.0=h439df22_0
+  - libwebp=1.2.4=h11a3e52_1
+  - libwebp-base=1.2.4=h5eee18b_1
+  - libxcb=1.15=h7f8727e_0
+  - libxkbcommon=1.0.1=h5eee18b_1
+  - libxml2=2.10.3=hcbfbd50_0
+  - libxslt=1.1.37=h2085143_0
+  - lz4-c=1.9.4=h6a678d5_0
+  - markupsafe=2.1.1=py38h7f8727e_0
+  - matplotlib=3.7.1=py38h06a4308_1
+  - matplotlib-base=3.7.1=py38h417a72b_1
+  - matplotlib-inline=0.1.6=py38h06a4308_0
+  - mkl=2023.1.0=h6d00ec8_46342
+  - mkl-service=2.4.0=py38h5eee18b_1
+  - mkl_fft=1.3.6=py38h417a72b_1
+  - mkl_random=1.2.2=py38h417a72b_1
+  - mpc=1.1.0=h10f8cd9_1
+  - mpfr=4.0.2=hb69a4c5_1
+  - mpmath=1.2.1=py38h06a4308_0
+  - munkres=1.1.4=py_0
+  - ncurses=6.4=h6a678d5_0
+  - nettle=3.7.3=hbbd107a_1
+  - networkx=2.8.4=py38h06a4308_1
+  - nspr=4.33=h295c915_0
+  - nss=3.74=h0370c37_0
+  - numexpr=2.8.4=py38hc78ab66_1
+  - numpy-base=1.24.3=py38h060ed82_1
+  - openh264=2.1.1=h4ff587b_0
+  - openssl=1.1.1t=h7f8727e_0
+  - pandas=1.5.3=py38h417a72b_0
+  - parso=0.8.3=pyhd3eb1b0_0
+  - pcre=8.45=h295c915_0
+  - pexpect=4.8.0=pyhd3eb1b0_3
+  - pickleshare=0.7.5=pyhd3eb1b0_1003
+  - pillow=9.4.0=py38h6a678d5_0
+  - pip=23.0.1=py38h06a4308_0
+  - ply=3.11=py38_0
+  - prompt-toolkit=3.0.36=py38h06a4308_0
+  - ptyprocess=0.7.0=pyhd3eb1b0_2
+  - pure_eval=0.2.2=pyhd3eb1b0_0
+  - pycparser=2.21=pyhd3eb1b0_0
+  - pygments=2.15.1=py38h06a4308_1
+  - pyopenssl=23.0.0=py38h06a4308_0
+  - pyparsing=3.0.9=py38h06a4308_0
+  - pyqt=5.15.7=py38h6a678d5_1
+  - pyqt5-sip=12.11.0=py38h6a678d5_1
+  - pysocks=1.7.1=py38h06a4308_0
+  - python=3.8.16=h7a1cb2a_3
+  - python-dateutil=2.8.2=pyhd3eb1b0_0
+  - pytorch=2.1.0.dev20230522=py3.8_cuda12.1_cudnn8.8.1_0
+  - pytorch-cuda=12.1=ha16c6d3_5
+  - pytorch-mutex=1.0=cuda
+  - pytz=2022.7=py38h06a4308_0
+  - pyyaml=6.0=py38h5eee18b_1
+  - qt-main=5.15.2=h8373d8f_8
+  - qt-webengine=5.15.9=hbbf29b9_6
+  - qtwebkit=5.212=h3fafdc1_5
+  - readline=8.2=h5eee18b_0
+  - requests=2.29.0=py38h06a4308_0
+  - setuptools=66.0.0=py38h06a4308_0
+  - sip=6.6.2=py38h6a678d5_0
+  - six=1.16.0=pyhd3eb1b0_1
+  - sqlite=3.41.2=h5eee18b_0
+  - stack_data=0.2.0=pyhd3eb1b0_0
+  - sympy=1.11.1=py38h06a4308_0
+  - tbb=2021.8.0=hdb19cb5_0
+  - tk=8.6.12=h1ccaba5_0
+  - toml=0.10.2=pyhd3eb1b0_0
+  - torchaudio=2.1.0.dev20230522=py38_cu121
+  - torchtriton=2.1.0+7d1a95b046=py38
+  - torchvision=0.16.0.dev20230522=py38_cu121
+  - tornado=6.2=py38h5eee18b_0
+  - traitlets=5.7.1=py38h06a4308_0
+  - typing_extensions=4.5.0=py38h06a4308_0
+  - urllib3=1.25.8=py38_0
+  - wcwidth=0.2.5=pyhd3eb1b0_0
+  - wheel=0.38.4=py38h06a4308_0
+  - x264=1!157.20191217=h7b6447c_0
+  - xz=5.4.2=h5eee18b_0
+  - yaml=0.2.5=h7b6447c_0
+  - zlib=1.2.13=h5eee18b_0
+  - zstd=1.5.5=hc292b87_0
+  - pip:
+    - absl-py==1.4.0
+    - astunparse==1.6.3
+    - audioread==3.0.0
+    - cachetools==5.3.0
+    - cython==0.29.34
+    - flatbuffers==2.0.7
+    - fsspec==2023.5.0
+    - gast==0.4.0
+    - google-auth==2.18.1
+    - google-auth-oauthlib==1.0.0
+    - google-pasta==0.2.0
+    - grpcio==1.54.2
+    - h5py==3.8.0
+    - hdbscan==0.8.29
+    - hearbaseline==2021.1.1
+    - huggingface-hub==0.14.1
+    - hyperpyyaml==1.2.0
+    - importlib-metadata==6.6.0
+    - joblib==1.2.0
+    - julius==0.2.7
+    - keras==2.7.0
+    - keras-preprocessing==1.1.2
+    - libclang==16.0.0
+    - librosa==0.9.1
+    - llvmlite==0.31.0
+    - markdown==3.4.3
+    - nnaudio==0.3.2
+    - numba==0.48.0
+    - numpy==1.19.2
+    - oauthlib==3.2.2
+    - opt-einsum==3.3.0
+    - packaging==23.1
+    - platformdirs==3.5.1
+    - pooch==1.7.0
+    - protobuf==3.19.6
+    - pyasn1==0.5.0
+    - pyasn1-modules==0.3.0
+    - pynndescent==0.5.10
+    - regex==2023.5.5
+    - requests-oauthlib==1.3.1
+    - resampy==0.2.2
+    - rsa==4.9
+    - ruamel-yaml==0.17.26
+    - ruamel-yaml-clib==0.2.7
+    - scikit-learn==1.2.2
+    - scipy==1.9.3
+    - sentencepiece==0.1.99
+    - soundfile==0.12.1
+    - speechbrain==0.5.14
+    - tensorboard==2.13.0
+    - tensorboard-data-server==0.7.0
+    - tensorflow==2.7.4
+    - tensorflow-estimator==2.7.0
+    - tensorflow-io-gcs-filesystem==0.32.0
+    - termcolor==2.3.0
+    - threadpoolctl==3.1.0
+    - tokenizers==0.13.3
+    - torchcrepe==0.0.19
+    - torchopenl3==1.0.1
+    - tqdm==4.65.0
+    - transformers==4.29.2
+    - umap-learn==0.5.3
+    - werkzeug==2.3.4
+    - wrapt==1.15.0
+    - zipp==3.15.0
+prefix: /home/paul.best/miniconda3/envs/hear
diff --git a/new_specie/compute_embeddings.py b/new_specie/compute_embeddings.py
index b25717cd09b095360ea5bd010ea1645ec949ebae..b849716bd2d221cf588502068af5eadaef804204 100755
--- a/new_specie/compute_embeddings.py
+++ b/new_specie/compute_embeddings.py
@@ -17,7 +17,8 @@ parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signa
 args = parser.parse_args()
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+#frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+frontend = models.frontend_gibbon
 encoder = models.sparrow_encoder(args.bottleneck // (args.nMel//32 * 4), (args.nMel//32, 4))
 decoder = models.sparrow_decoder(args.bottleneck, (args.nMel//32, 4))
 model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
@@ -25,7 +26,7 @@ model = torch.nn.Sequential(frontend, encoder, decoder).to(device)
 df = pd.read_csv(args.detections)
 
 print('Computing AE projections...')
-loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8)
+loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur, channel=1), batch_size=16, shuffle=False, num_workers=8, prefetch_factor=8)
 with torch.no_grad():
     encodings, idxs = [], []
     for x, idx in tqdm(loader):
diff --git a/new_specie/fetch_annot_from_pngs.py b/new_specie/fetch_annot_from_pngs.py
index 0427b3b21cae40e1e0c3f70e887c8dbb6f5e63fb..09fce4652e96a506ad2eda6c6fd3a689bb37fec3 100755
--- a/new_specie/fetch_annot_from_pngs.py
+++ b/new_specie/fetch_annot_from_pngs.py
@@ -12,7 +12,7 @@ parser.add_argument('annot_folder', type=str, help='Name of the folder containin
 parser.add_argument("detections", type=str, help=".csv file with detections that were clustered (labels will be added to it)")
 args = parser.parse_args()
 
-df = pd.read_csv(args.detections))
+df = pd.read_csv(args.detections)
 
 for label in os.listdir(args.annot_folder+'/'):
     for file in os.listdir(f'{args.annot_folder}/{label}/'):
diff --git a/new_specie/filterbank.py b/new_specie/filterbank.py
index 32e9780e198b55c7388d39ec6e386bd6c29c2213..34ecf14f9750b72dec34fa7f9574398216f044ac 100755
--- a/new_specie/filterbank.py
+++ b/new_specie/filterbank.py
@@ -123,6 +123,14 @@ class STFT(nn.Module):
         x = x.reshape((batchsize, channels, -1) + x.shape[2:]) #if channels > 1 else x.reshape((batchsize, -1) + x.shape[2:])
         return x
 
+class MedFilt(nn.Module):
+    """
+    Subtract the 20th percentile of each frequency band (a robust per-band noise-floor estimate)
+    """
+    def __init__(self):
+        super(MedFilt, self).__init__()
+    def forward(self, x):
+        return x - torch.quantile(x, 0.2, dim=-1, keepdim=True)
 
 
 class TemporalBatchNorm(nn.Module):
@@ -161,3 +169,4 @@ class Log1p(nn.Module):
 
     def extra_repr(self):
         return 'trainable={}'.format(repr(self.trainable))
+
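# Editor's note -- illustrative sketch, not part of the diff: what the new
# MedFilt module above computes. It subtracts a per-frequency-band noise-floor
# estimate (the 0.2 quantile over the time axis) from a spectrogram, which
# suppresses stationary background noise. Note that torch.quantile returns a
# plain Tensor (unlike torch.median's (values, indices) tuple), hence no [0]
# indexing. Self-contained re-implementation for demonstration only.
import torch
import torch.nn as nn

class MedFiltDemo(nn.Module):
    def forward(self, x):  # x: (batch, channels, freq, time)
        return x - torch.quantile(x, 0.2, dim=-1, keepdim=True)

spec = torch.rand(1, 1, 96, 128) + 2.0   # dummy log-Mel spectrogram with a constant offset
out = MedFiltDemo()(spec)
print(out.shape)                          # torch.Size([1, 1, 96, 128]); the offset is removed per band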
diff --git a/new_specie/models.py b/new_specie/models.py
index 48d7cf84ec44fc40247731fe0c8b15cbbb54e559..f7dcc80573136e04eed5839b15caf38a11c030c7 100755
--- a/new_specie/models.py
+++ b/new_specie/models.py
@@ -21,7 +21,7 @@ frontend = lambda sr, nfft, sampleDur, n_mel : nn.Sequential(
 sparrow_encoder = lambda nfeat, shape : nn.Sequential(
     nn.Conv2d(1, 32, 3, stride=2, bias=False, padding=(1)),
     nn.BatchNorm2d(32),
-    nn.LeakyReLU(0.01),
+    nn.ReLU(True),
     nn.Conv2d(32, 64, 3, stride=2, bias=False, padding=1),
     nn.BatchNorm2d(64),
     nn.ReLU(True),
@@ -31,7 +31,7 @@ sparrow_encoder = lambda nfeat, shape : nn.Sequential(
     nn.Conv2d(128, 256, 3, stride=2, bias=False, padding=1),
     nn.BatchNorm2d(256),
     nn.ReLU(True),
-    nn.Conv2d(256, nfeat, (3, 5), stride=2, padding=(1, 2)),
+    nn.Conv2d(256, nfeat, 3, stride=2, padding=1),
     u.Reshape(nfeat * shape[0] * shape[1])
 )
 
@@ -72,11 +72,8 @@ sparrow_decoder = lambda nfeat, shape : nn.Sequential(
 
     nn.ReLU(True),
     nn.Upsample(scale_factor=2),
-    nn.Conv2d(32, 32, (3, 3), bias=False, padding=1),
-    nn.BatchNorm2d(32),
-    nn.ReLU(True),
     nn.Conv2d(32, 1, (3, 3), bias=False, padding=1),
-    nn.ReLU(True)
+    nn.BatchNorm2d(1),
+    nn.ReLU(True),
+    nn.Conv2d(1, 1, (3, 3), bias=False, padding=1),
 )
-
-
diff --git a/new_specie/requirements.txt b/new_specie/requirements.txt
index 5a9896469de69c28f190fb8131c99b03c41de438..9b25a7728b5ec988c9d7c266b42198f9f1240d0c 100755
--- a/new_specie/requirements.txt
+++ b/new_specie/requirements.txt
@@ -1,12 +1,44 @@
+albumentations==1.3.0
+comet_ml==3.33.3
+coremltools==6.3.0
+Flask==2.3.2
+get_file==0.1.6
+GitPython==3.1.31
 hdbscan==0.8.28
-matplotlib==3.5.1
-numpy==1.22.3
-pandas==1.4.1
-scipy==1.8.0
-sounddevice==0.4.5
+ipython==8.5.0
+matplotlib==3.6.0
+mss==9.0.1
+numpy==1.23.5
+onnx==1.14.0
+onnxruntime==1.15.0
+onnxsim==0.4.28
+opencv_python==4.7.0.72
+openvino==2022.3.0
+paddle==1.0.2
+pafy==0.5.5
+pandas==1.5.0
+Pillow==9.5.0
+plotly==5.13.0
+psutil==5.9.4
+pycocotools==2.0.6
+PyDrive==1.3.1
+PySoundFile==0.9.0.post1
+PyYAML==6.0
+Requests==2.31.0
+scipy==1.9.1
+seaborn==0.12.2
+setuptools==67.6.1
+sounddevice==0.4.6
 soundfile==0.11.0
-torch==1.11.0+cu113
-torchvision==0.12.0+cu113
-tqdm==4.64.0
-umap==0.1.1
+tensorflow==2.12.0
+tensorflowjs==4.6.0
+tensorrt==8.6.1
+tflite_runtime==2.12.0
+tflite_support==0.4.3
+thop==0.1.1.post2209072238
+torch==1.12.1+cu113
+torchvision==0.13.1+cu113
+tqdm==4.64.1
+tritonclient==2.33.0
 umap_learn==0.5.3
+x2paddle==1.4.1
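# Editor's note -- hedged sanity check for the models.py changes above (assumes
# new_specie/models.py as modified is importable; shapes follow the
# compute_embeddings.py call with nMel=128 and bottleneck=256; the printed
# shapes are expectations, not outputs verified against the repo).
import torch
import models

nMel, bottleneck = 128, 256
shape = (nMel // 32, 4)                 # (4, 4) latent grid for 128 Mel bands
encoder = models.sparrow_encoder(bottleneck // (shape[0] * shape[1]), shape)
decoder = models.sparrow_decoder(bottleneck, shape)

x = torch.randn(2, 1, nMel, 128)        # dummy batch of 128x128 spectrograms
z = encoder(x)                          # five stride-2 convs: 128 -> 64 -> 32 -> 16 -> 8 -> 4
print(z.shape)                          # expected: torch.Size([2, 256])
x_hat = decoder(z)
print(x_hat.shape)                      # expected: torch.Size([2, 1, 128, 128])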
diff --git a/new_specie/sort_cluster.py b/new_specie/sort_cluster.py
index 6ab1a705ffff57582999dfc6986d2c7f3105cd6d..2f36c31d77efa6cae9aab14d48b60b643693faf7 100755
--- a/new_specie/sort_cluster.py
+++ b/new_specie/sort_cluster.py
@@ -5,6 +5,7 @@ from tqdm import tqdm
 import matplotlib.pyplot as plt
 import os
 import torch, numpy as np, pandas as pd
+from filterbank import STFT, MelFilter, MedFilt, Log1p
 import hdbscan
 import argparse
 import models
@@ -22,17 +23,18 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFo
 parser.add_argument('encodings', type=str, help='.npy file containing umap projections and their associated index in the detection.pkl table (built using compute_embeddings.py)')
 parser.add_argument('detections', type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
 #parser.add_argument('audio_folder', type=str, help='Path to the folder with complete audio files')
-parser.add_argument("-audio_folder", type=str, default='', help="Folder from which to load sound files")
+parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
 parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
 parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
 parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
 parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
 parser.add_argument('-min_cluster_size', type=int, default=10, help='Used for HDBSCAN clustering.')
+parser.add_argument('-channel', type=int, default=0)
 parser.add_argument('-min_sample', type=int, default=5, help='Used for HDBSCAN clustering.')
 parser.add_argument('-eps', type=float, default=0.05, help='Used for HDBSCAN clustering.')
 args = parser.parse_args()
 
-df = pd.read_csv(args.detections)
+df = pd.read_csv(args.detections, index_col=0)
 encodings = np.load(args.encodings, allow_pickle=True).item()
 idxs, umap = encodings['idx'], encodings['umap']
 df.loc[idxs, 'umap_x'] = umap[:,0]
@@ -46,7 +48,14 @@ df.at[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
                                          cluster_selection_method='leaf').fit_predict(umap)
 df.cluster = df.cluster.astype(int)
 
-frontend = models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
+fs = 44100
+frontend = torch.nn.Sequential(
+    STFT(2048, 256),
+    MelFilter(fs, 2048, 96, 500, 4000),
+    Log1p(4),
+    MedFilt()
+)
+
 
 figscat = plt.figure(figsize=(10, 5))
 plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
@@ -73,7 +82,7 @@ class temp():
         dur, fs = info.duration, info.samplerate
         start = int(np.clip(row.pos - args.sampleDur/2, 0, dur - args.sampleDur) * fs)
         sig, fs = sf.read(f'{args.audio_folder}/{row.filename}', start=start, stop=start + int(args.sampleDur*fs), always_2d=True)
-        sig = sig[:,0]
+        sig = sig[:, args.channel]
         if fs != args.SR:
             sig = signal.resample(sig, int(len(sig)/fs*args.SR))
         spec = frontend(torch.Tensor(sig).view(1, -1).float()).detach().squeeze()
@@ -88,7 +97,7 @@ class temp():
         figscat.canvas.draw()
         # Play the audio
         if soundAvailable:
-            sd.play(sig*10, fs)
+            sd.play(sig, fs)
 
 mtemp = temp()
 cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
@@ -108,7 +117,7 @@ for c, grp in df.groupby('cluster'):
     loader = torch.utils.data.DataLoader(u.Dataset(grp.sample(min(len(grp), 200)), args.audio_folder, args.SR, args.sampleDur), batch_size=1, num_workers=8)
     with torch.no_grad():
         for x, idx in tqdm(loader, leave=False, desc=str(c)):
-            plt.specgram(x.squeeze().numpy())
-            plt.tight_layout()
+            plt.imshow(frontend(x).squeeze().numpy(), origin='lower', aspect='auto')
+            plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
             plt.savefig(f'cluster_pngs/{c}/{idx.squeeze().item()}')
             plt.close()
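# Editor's note -- minimal sketch of the clustering step used in sort_cluster.py
# above (assumes a .npy encodings file written by compute_embeddings.py, holding
# a dict with 'idx' and 'umap' arrays; file names here are hypothetical).
import numpy as np
import pandas as pd
import hdbscan

encodings = np.load('encodings_mydata.npy', allow_pickle=True).item()
idxs, umap = encodings['idx'], encodings['umap']

clusters = hdbscan.HDBSCAN(
    min_cluster_size=10,              # smallest group still reported as a cluster
    min_samples=5,                    # larger values make the clustering more conservative
    cluster_selection_epsilon=0.05,   # merges clusters closer than this distance
    core_dist_n_jobs=-1,
    cluster_selection_method='leaf',  # favours many small, homogeneous clusters
).fit_predict(umap)

df = pd.read_csv('detections.csv', index_col=0)   # hypothetical detections table
df.loc[idxs, 'cluster'] = clusters.astype(int)    # -1 marks points labelled as noise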
diff --git a/new_specie/utils.py b/new_specie/utils.py
index 5c6644f6546275418b03cf9741321b4de758a0d6..09acb635fbd5b8f5b07d63910e52731cf8d5cb38 100755
--- a/new_specie/utils.py
+++ b/new_specie/utils.py
@@ -10,9 +10,9 @@ def collate_fn(batch):
     return dataloader.default_collate(batch)
 
 class Dataset(Dataset):
-    def __init__(self, df, audiopath, sr, sampleDur):
+    def __init__(self, df, audiopath, sr, sampleDur, channel=0):
         super(Dataset, self)
-        self.audiopath, self.df, self.sr, self.sampleDur = audiopath, df, sr, sampleDur
+        self.audiopath, self.df, self.sr, self.sampleDur, self.channel = audiopath, df, sr, sampleDur, channel
 
     def __len__(self):
         return len(self.df)
@@ -24,9 +24,9 @@ class Dataset(Dataset):
             dur, fs = info.duration, info.samplerate
             start = int(np.clip(row.pos - self.sampleDur/2, 0, max(0, dur - self.sampleDur)) * fs)
             sig, fs = sf.read(self.audiopath+'/'+row.filename, start=start, stop=start + int(self.sampleDur*fs), always_2d=True)
-            sig = sig[:,0]
-        except:
-            print(f'Failed to load sound from row {row.name} with filename {row.filename}')
+            sig = sig[:, row.Channel - 1 if 'Channel' in row else self.channel]
+        except Exception as e:
+            print(f'Failed to load sound from row {row.name} with filename {row.filename}', e)
             return None
         if len(sig) < self.sampleDur * fs:
             sig = np.concatenate([sig, np.zeros(int(self.sampleDur * fs) - len(sig))])
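# Editor's note -- usage sketch for the extended Dataset above (assumes
# new_specie/utils.py as modified is importable; 'detections.csv' and the audio
# folder are hypothetical).
import pandas as pd
import torch
import utils as u

df = pd.read_csv('detections.csv')   # needs 'filename' and 'pos' columns
# channel=1 selects the second channel of multi-channel recordings; a per-row
# 'Channel' column (1-based) in the dataframe takes precedence when present.
dataset = u.Dataset(df, audiopath='./audio', sr=44100, sampleDur=1, channel=1)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=u.collate_fn)
for sig, idx in loader:
    print(sig.shape)                 # presumably (batch, sampleDur * sr) waveform extracts
    break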
diff --git a/PCEN_pytorch.py b/paper_experiments/PCEN_pytorch.py
similarity index 100%
rename from PCEN_pytorch.py
rename to paper_experiments/PCEN_pytorch.py
diff --git a/bengalese_finch1/bengalese_finch1.csv b/paper_experiments/bengalese_finch1/bengalese_finch1.csv
similarity index 100%
rename from bengalese_finch1/bengalese_finch1.csv
rename to paper_experiments/bengalese_finch1/bengalese_finch1.csv
diff --git a/bengalese_finch1/citation.txt b/paper_experiments/bengalese_finch1/citation.txt
similarity index 100%
rename from bengalese_finch1/citation.txt
rename to paper_experiments/bengalese_finch1/citation.txt
diff --git a/bengalese_finch2/bengalese_finch2.csv b/paper_experiments/bengalese_finch2/bengalese_finch2.csv
similarity index 100%
rename from bengalese_finch2/bengalese_finch2.csv
rename to paper_experiments/bengalese_finch2/bengalese_finch2.csv
diff --git a/bengalese_finch2/citation.txt b/paper_experiments/bengalese_finch2/citation.txt
similarity index 100%
rename from bengalese_finch2/citation.txt
rename to paper_experiments/bengalese_finch2/citation.txt
diff --git a/black-headed_grosbeaks/black-headed_grosbeaks.csv b/paper_experiments/black-headed_grosbeaks/black-headed_grosbeaks.csv
similarity index 100%
rename from black-headed_grosbeaks/black-headed_grosbeaks.csv
rename to paper_experiments/black-headed_grosbeaks/black-headed_grosbeaks.csv
diff --git a/black-headed_grosbeaks/filelist.txt b/paper_experiments/black-headed_grosbeaks/filelist.txt
similarity index 100%
rename from black-headed_grosbeaks/filelist.txt
rename to paper_experiments/black-headed_grosbeaks/filelist.txt
diff --git a/black-headed_grosbeaks/source.txt b/paper_experiments/black-headed_grosbeaks/source.txt
similarity index 100%
rename from black-headed_grosbeaks/source.txt
rename to paper_experiments/black-headed_grosbeaks/source.txt
diff --git a/california_thrashers/california_thrashers.csv b/paper_experiments/california_thrashers/california_thrashers.csv
similarity index 100%
rename from california_thrashers/california_thrashers.csv
rename to paper_experiments/california_thrashers/california_thrashers.csv
diff --git a/california_thrashers/filelist.txt b/paper_experiments/california_thrashers/filelist.txt
similarity index 100%
rename from california_thrashers/filelist.txt
rename to paper_experiments/california_thrashers/filelist.txt
diff --git a/california_thrashers/source.txt b/paper_experiments/california_thrashers/source.txt
similarity index 100%
rename from california_thrashers/source.txt
rename to paper_experiments/california_thrashers/source.txt
diff --git a/cassin_vireo/cassin_vireo.csv b/paper_experiments/cassin_vireo/cassin_vireo.csv
similarity index 100%
rename from cassin_vireo/cassin_vireo.csv
rename to paper_experiments/cassin_vireo/cassin_vireo.csv
diff --git a/cassin_vireo/filelist.txt b/paper_experiments/cassin_vireo/filelist.txt
similarity index 100%
rename from cassin_vireo/filelist.txt
rename to paper_experiments/cassin_vireo/filelist.txt
diff --git a/cassin_vireo/source.txt b/paper_experiments/cassin_vireo/source.txt
similarity index 100%
rename from cassin_vireo/source.txt
rename to paper_experiments/cassin_vireo/source.txt
diff --git a/compute_embeddings.py b/paper_experiments/compute_embeddings.py
similarity index 100%
rename from compute_embeddings.py
rename to paper_experiments/compute_embeddings.py
diff --git a/dolphin/dolphin.csv b/paper_experiments/dolphin/dolphin.csv
similarity index 100%
rename from dolphin/dolphin.csv
rename to paper_experiments/dolphin/dolphin.csv
diff --git a/filterbank.py b/paper_experiments/filterbank.py
similarity index 100%
rename from filterbank.py
rename to paper_experiments/filterbank.py
diff --git a/good_species.txt b/paper_experiments/good_species.txt
similarity index 100%
rename from good_species.txt
rename to paper_experiments/good_species.txt
diff --git a/hdbscan_gridsearch.py b/paper_experiments/hdbscan_gridsearch.py
similarity index 100%
rename from hdbscan_gridsearch.py
rename to paper_experiments/hdbscan_gridsearch.py
diff --git a/humpback/humpback.csv b/paper_experiments/humpback/humpback.csv
similarity index 100%
rename from humpback/humpback.csv
rename to paper_experiments/humpback/humpback.csv
diff --git a/humpback2/extract_annot.py b/paper_experiments/humpback2/extract_annot.py
similarity index 100%
rename from humpback2/extract_annot.py
rename to paper_experiments/humpback2/extract_annot.py
diff --git a/humpback2/humpback2.csv b/paper_experiments/humpback2/humpback2.csv
similarity index 100%
rename from humpback2/humpback2.csv
rename to paper_experiments/humpback2/humpback2.csv
diff --git a/models.py b/paper_experiments/models.py
similarity index 100%
rename from models.py
rename to paper_experiments/models.py
diff --git a/plot_annot_distrib.py b/paper_experiments/plot_annot_distrib.py
similarity index 100%
rename from plot_annot_distrib.py
rename to paper_experiments/plot_annot_distrib.py
diff --git a/plot_clusters.py b/paper_experiments/plot_clusters.py
similarity index 65%
rename from plot_clusters.py
rename to paper_experiments/plot_clusters.py
index e71515f124b9a60d94eefea46c14738fa542e5e2..511bdfb7c43973160da9172bff7414cd5206c53c 100755
--- a/plot_clusters.py
+++ b/paper_experiments/plot_clusters.py
@@ -9,19 +9,22 @@ fig, ax = plt.subplots(nrows=len(species), figsize=(7, 10), dpi=200)
 for i, specie in enumerate(species):
     meta = models.meta[specie]
     frontend = models.frontend['pcenMel'](meta['sr'], meta['nfft'], meta['sampleDur'], 128)
-    dic = np.load(f'{specie}/encodings//encodings_{specie}_256_pcenMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
-    idxs, X = dic['idxs'], dic['umap']
+    dic = np.load(f'{specie}/encodings//encodings_{specie}_256_logMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
+    idxs, X = dic['idxs'], dic['umap8']
     df = pd.read_csv(f'{specie}/{specie}.csv')
-    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.05, core_dist_n_jobs=-1, cluster_selection_method='leaf').fit_predict(X)
+    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpback' in specie else 'eom').fit_predict(X)
     df.loc[idxs, 'cluster'] = clusters.astype(int)
     for j, cluster in enumerate(np.random.choice(np.arange(df.cluster.max()), 4)):
-        for k, (x, name) in enumerate(torch.utils.data.DataLoader(u.Dataset(df[df.cluster==cluster].sample(8), f'{specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=1)):
+        loader = torch.utils.data.DataLoader(u.Dataset(df[df.cluster==cluster].sample(8), f'{specie}/audio/', meta['sr'], meta['sampleDur']), batch_size=1)
+        for k, (x, name) in enumerate(loader):
             spec = frontend(x).squeeze().numpy()
             ax[i].imshow(spec, extent=[k, k+1, j, j+1], origin='lower', aspect='auto', cmap='Greys', vmin=np.quantile(spec, .2), vmax=np.quantile(spec, .98))
     ax[i].set_xticks([])
     ax[i].set_yticks([])
+    if specie == 'humpback2':
+        specie = 'humpback\n(small)'
     # ax[i].grid(color='w', xdata=np.arange(1, 10), ydata=np.arange(1, 5))
-    ax[i].set_ylabel(specie.replace('_', ' '))
+    ax[i].set_ylabel(specie.replace('_', '\n'))
     ax[i].set_xlim(0, 8)
     ax[i].set_ylim(0, 4)
     ax[i].vlines(np.arange(1, 8), 0, 4, linewidths=1, color='black')
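# Editor's note -- sketch of the grid-plotting trick used in plot_clusters.py
# above: imshow's `extent` argument places each spectrogram into its own unit
# cell of a shared axis. Dummy data stands in for frontend(x).squeeze().numpy();
# the output file name is hypothetical.
import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 4))
for j in range(4):            # rows: one per cluster
    for k in range(8):        # columns: one per sampled vocalisation
        spec = np.random.rand(64, 32)
        ax.imshow(spec, extent=[k, k + 1, j, j + 1], origin='lower', aspect='auto',
                  cmap='Greys', vmin=np.quantile(spec, .2), vmax=np.quantile(spec, .98))
ax.set_xlim(0, 8)
ax.set_ylim(0, 4)
ax.vlines(np.arange(1, 8), 0, 4, linewidths=1, color='black')  # cell separators
ax.set_xticks([]); ax.set_yticks([])
plt.savefig('cluster_grid_demo.png')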
diff --git a/plot_main_results.py b/paper_experiments/plot_main_results.py
similarity index 50%
rename from plot_main_results.py
rename to paper_experiments/plot_main_results.py
index 75de2a853666619347d35ae92f30b20c939fd857..b766250aed36dcf63d6414b900908843b6708ea8 100755
--- a/plot_main_results.py
+++ b/paper_experiments/plot_main_results.py
@@ -9,8 +9,10 @@ all_frontendNames = [['AE prcptl', 'AE MSE', 'Spectro.', 'PAFs'][::-1],
                      ['AE', 'gen. AE', 'OpenL3', 'Wav2Vec2', 'CREPE'][::-1],
                      ['log-Mel', 'Mel', 'PCEN-Mel', 'log-STFT']]
+# Bar plots of NMI depending on the choice of feature extraction
 all_plotNames = ['handcraft', 'deepembed', 'frontends']
 for frontends, frontendNames, plotName in zip(all_frontends, all_frontendNames, all_plotNames):
+    print(plotName)
     df = pd.read_csv('hdbscan_HP.csv')
     df.loc[df.ms.isna(), 'ms'] = 0
     best = df.loc[df.groupby(["specie", 'frontend']).nmi.idxmax()]
@@ -29,3 +31,46 @@ for frontends, frontendNames, plotName in zip(all_frontends, all_frontendNames,
     plt.grid(axis='y')
     plt.tight_layout()
     plt.savefig(f'NMIs_hdbscan_barplot_{plotName}.pdf')
+
+
+# Plot of NMI depending on UMAP nb of components
+fig, ax = plt.subplots(ncols=2, sharey=True, sharex=True, figsize=(10, 3.5))
+# Left panel with auto-encoder feature extraction
+df = pd.read_csv("hdbscan_HP_archive2.csv")
+df.loc[df.ms.isna(), "ms"] = 0
+df = df[((df.frontend == "256_logMel128")&(df.al=='leaf')&(df.mcs==10)&(df.ms==3)&(df.eps==.1))].sort_values('specie')
+#df = df[df.frontend == "256_logMel128"].sort_values('specie')
+for s, grp in df.groupby("specie"):
+    ax[0].plot(
+        np.arange(5),
+        [grp[grp.ncomp == n].nmi.max() / grp.nmi.max() for n in 2 ** np.arange(1, 6)],
+    )
+# Right panel with Spectrogram feature extraction
+df = pd.read_csv("hdbscan_HP_archive2.csv")
+df.loc[df.ms.isna(), "ms"] = 0
+df = df[((df.frontend == "spec32")&(df.al=='leaf')&(df.mcs==10)&(df.ms==3)&(df.eps==.1))].sort_values('specie')
+#df = df[df.frontend == "spec32"].sort_values('specie')
+for s, grp in df.groupby("specie"):
+    if s.endswith('s'):
+        s = s[:-1]
+    if s == 'humpback2':
+        s = 'humpback (small)'
+    ax[1].plot(
+        np.arange(5),
+        [grp[grp.ncomp == n].nmi.max() / grp.nmi.max() for n in 2 ** np.arange(1, 6)],
+        label=s.replace('_',' '),
+    )
+plt.legend()
+ax[0].set_ylabel("NMI / max(NMI)")
+ax[1].set_ylim(0.85, 1.01)
+ax[0].grid()
+ax[1].grid()
+ax[0].set_title("auto-encoder")
+ax[1].set_title("spectrogram")
+ax[0].set_xlabel("# UMAP dimensions")
+ax[1].set_xlabel("# UMAP dimensions")
+ax[0].set_xticks(np.arange(5))
+ax[0].set_xticklabels(2 ** np.arange(1, 6))
+ax[0].set_yticks(np.arange(0.85, 1.01, 0.05))
+plt.tight_layout()
+plt.savefig('NMI_umap.pdf')
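# Editor's note -- background sketch for the NMI scores plotted above:
# normalized mutual information compares cluster assignments with expert labels
# (0 = independent, 1 = perfect agreement); unannotated detections are masked
# out before scoring. Toy data only.
import numpy as np
from sklearn import metrics

labels   = np.array(['A', 'A', 'B', 'B', 'B', 'C'])   # toy expert annotations
clusters = np.array([ 0,   0,   1,   1,   2,   2 ])   # toy HDBSCAN output
nmi = metrics.normalized_mutual_info_score(labels, clusters)
print(nmi)   # about 0.74 with sklearn's default (arithmetic-mean) normalisation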
diff --git a/plot_prec_rec.py b/paper_experiments/plot_prec_rec.py
similarity index 95%
rename from plot_prec_rec.py
rename to paper_experiments/plot_prec_rec.py
index 22f38f797ae156331408c01fbb871f29c3e9d6f9..7d4615488947af78e66243f22a29053f9f9dc825 100755
--- a/plot_prec_rec.py
+++ b/paper_experiments/plot_prec_rec.py
@@ -22,12 +22,12 @@ for specie in species:
     dic = np.load(f'{specie}/encodings//encodings_{specie}_256_logMel128_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
     idxs, X = dic['idxs'], dic['umap8']
     df = pd.read_csv(f'{specie}/{specie}.csv')
-    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpbacjjk' in specie else 'eom').fit_predict(X)
+    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpback' in specie else 'eom').fit_predict(X)
     df.loc[idxs, 'cluster'] = clusters.astype(int)
 
     dic = np.load(f'{specie}/encodings/encodings_spec32.npy', allow_pickle=True).item()
     idxs, X = dic['idxs'], dic['umap8']
-    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpbacjjk' in specie else 'eom').fit_predict(X)
+    clusters = hdbscan.HDBSCAN(min_cluster_size=10, min_samples=3, cluster_selection_epsilon=0.1, core_dist_n_jobs=-1, cluster_selection_method='leaf' if not 'humpback' in specie else 'eom').fit_predict(X)
     df.loc[idxs, 'cluster2'] = clusters.astype(int)
 
     mask = ~df.loc[idxs].label.isna()
diff --git a/plot_projections.py b/paper_experiments/plot_projections.py
similarity index 100%
rename from plot_projections.py
rename to paper_experiments/plot_projections.py
diff --git a/run_baseline.py b/paper_experiments/run_PAF_baseline.py
similarity index 100%
rename from run_baseline.py
rename to paper_experiments/run_PAF_baseline.py
diff --git a/run_hearbaseline.py b/paper_experiments/run_hearbaseline.py
similarity index 100%
rename from run_hearbaseline.py
rename to paper_experiments/run_hearbaseline.py
diff --git a/run_spec32_baseline.py b/paper_experiments/run_spec32_baseline.py
similarity index 100%
rename from run_spec32_baseline.py
rename to paper_experiments/run_spec32_baseline.py
diff --git a/test_AE.py b/paper_experiments/test_AE.py
similarity index 100%
rename from test_AE.py
rename to paper_experiments/test_AE.py
diff --git a/test_AE_all.py b/paper_experiments/test_AE_all.py
similarity index 100%
rename from test_AE_all.py
rename to paper_experiments/test_AE_all.py
diff --git a/train_AE.py b/paper_experiments/train_AE.py
similarity index 100%
rename from train_AE.py
rename to paper_experiments/train_AE.py
diff --git a/train_AE_all.py b/paper_experiments/train_AE_all.py
similarity index 100%
rename from train_AE_all.py
rename to paper_experiments/train_AE_all.py
diff --git a/utils.py b/paper_experiments/utils.py
similarity index 100%
rename from utils.py
rename to paper_experiments/utils.py
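# Editor's note -- hedged sketch of a per-cluster precision/recall computation
# in the spirit of plot_prec_rec.py (the exact definitions used in the paper
# may differ): each cluster gets its majority expert label; precision is the
# purity of the cluster, recall the share of that label's points it recovers.
import pandas as pd

df = pd.DataFrame({'label':   ['A', 'A', 'B', 'B', 'B', 'C'],
                   'cluster': [ 0,   0,   1,   1,   2,   2 ]})
for c, grp in df[df.cluster != -1].groupby('cluster'):   # -1 is HDBSCAN noise
    majority = grp.label.mode()[0]
    precision = (grp.label == majority).mean()
    recall = (grp.label == majority).sum() / (df.label == majority).sum()
    print(f'cluster {c}: label {majority}, precision {precision:.2f}, recall {recall:.2f}')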
diff --git a/plot_results_kmeans.py b/plot_results_kmeans.py
deleted file mode 100755
index a60bdb761eb54e3de15ab9d1465d9219716a8ac9..0000000000000000000000000000000000000000
--- a/plot_results_kmeans.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import hdbscan
-import pandas as pd
-import matplotlib.pyplot as plt
-import numpy as np
-from sklearn import metrics, cluster
-from scipy.stats import linregress
-
-species = np.loadtxt('good_species.txt', dtype=str)
-frontends = ['16_pcenMel128', '16_logMel128', '16_logSTFT', '16_Mel128', '8_pcenMel64', '32_pcenMel128']
-plt.figure()
-for specie in species:
-    df = pd.read_csv(f'{specie}/{specie}.csv')
-    nmis = []
-    for i, frontend in enumerate(frontends):
-        print(specie, frontend)
-        dic = np.load(f'{specie}/encodings_{specie}_{frontend}_sparrow_encoder_decod2_BN_nomaxPool.npy', allow_pickle=True).item()
-        idxs, encodings, X = dic['idxs'], dic['encodings'], dic['umap']
-
-        ks = (5*1.2**np.arange(20)).astype(int)
-        distorsions = [cluster.KMeans(n_clusters=k).fit(encodings).inertia_ for k in ks]
-        errors = [linregress(ks[:i], distorsions[:i]).stderr + linregress(ks[i+1:], distorsions[i+1:]).stderr for i in range(2, len(ks)-2)]
-        k = ks[np.argmin(errors)]
-        clusters = cluster.KMeans(n_clusters=k).fit_predict(encodings)
-        df.loc[idxs, 'cluster'] = clusters.astype(int)
-
-        mask = ~df.loc[idxs].label.isna()
-        clusters, labels = clusters[mask], df.loc[idxs[mask]].label
-        nmis.append(metrics.normalized_mutual_info_score(labels, clusters))
-        df.drop('cluster', axis=1, inplace=True)
-    plt.scatter(nmis, np.arange(len(frontends)), label=specie)
-
-plt.yticks(range(len(frontends)), frontends)
-plt.ylabel('archi')
-plt.xlabel('NMI with expert labels')
-plt.grid()
-plt.legend()
-plt.tight_layout()
-plt.savefig('NMIs_kmeans.pdf')
:/" - plt.figure() - plt.imshow(x, origin='lower', aspect='auto', cmap='Greys', vmin=np.quantile(x, .7)) - plt.subplots_adjust(top=1, bottom=0, left=0, right=1) - row = df.loc[idx.item()] - #plt.savefig(f'{args.specie}/annot_pngs/{label}/{row.fn.split(".")[0]}_{row.pos:.2f}.png') - plt.savefig(f'{args.specie}/annot_pngs/{label}/{idx.item()}') - plt.close() - - - diff --git a/print_reconstr.py b/print_reconstr.py deleted file mode 100755 index 66efcbf29cf6e88038362fd63b9ed3efd7fb1068..0000000000000000000000000000000000000000 --- a/print_reconstr.py +++ /dev/null @@ -1,48 +0,0 @@ -import os, shutil, argparse -from tqdm import tqdm -import matplotlib.pyplot as plt -import pandas as pd, numpy as np -import models, utils as u -import torch - -parser = argparse.ArgumentParser() -parser.add_argument("specie", type=str) -parser.add_argument("-bottleneck", type=int, default=16) -parser.add_argument("-frontend", type=str, default='logMel') -parser.add_argument("-encoder", type=str, default='sparrow_encoder') -parser.add_argument("-prcptl", type=int, default=1) -parser.add_argument("-nMel", type=int, default=128) -args = parser.parse_args() - -modelname = f'{args.specie}_{args.bottleneck}_{args.frontend}{args.nMel if "Mel" in args.frontend else ""}_{args.encoder}_decod2_BN_nomaxPool{"_noprcptl" if args.prcptl==0 else ""}.stdc' -print(modelname) -gpu = torch.device(f'cuda') - -meta = models.meta[args.specie] -frontend = models.frontend[args.frontend](meta['sr'], meta['nfft'], meta['sampleDur'], args.nMel) -encoder = models.__dict__[args.encoder](*((args.bottleneck // 16, (4, 4)) if args.nMel == 128 else (args.bottleneck // 8, (2, 4)))) -decoder = models.sparrow_decoder(args.bottleneck, (4, 4) if args.nMel == 128 else (2, 4)) -model = torch.nn.Sequential(frontend, encoder, decoder).to(gpu) -model.load_state_dict(torch.load(f'{args.specie}/weights/{modelname}')) -model.eval() - -df = pd.read_csv(f'{args.specie}/{args.specie}.csv') - -shutil.rmtree(f'{args.specie}/reconstruct_pngs', ignore_errors=True) - -for label, grp in df.groupby('label'): - os.makedirs(f'{args.specie}/reconstruct_pngs/{label}', exist_ok=True) - loader = torch.utils.data.DataLoader(u.Dataset(grp, args.specie+'/audio/', meta['sr'], meta['sampleDur']),\ - batch_size=1, num_workers=4, pin_memory=True) - with torch.no_grad(): - for x, idx in tqdm(loader, desc=args.specie + ' ' + label, leave=False): - x = model(x.to(gpu)).squeeze().detach().cpu() - assert not torch.isnan(x).any(), "Found a NaN in spectrogram... 
:/" - plt.imshow(x, origin='lower', aspect='auto') #, cmap='Greys', vmin=np.quantile(x, .7)) - plt.subplots_adjust(top=1, bottom=0, left=0, right=1) - row = df.loc[idx.item()] - plt.savefig(f'{args.specie}/reconstruct_pngs/{label}/{idx.item()}') - plt.close() - - - diff --git a/sort_cluster.py b/sort_cluster.py deleted file mode 100755 index ac8d92e854dff183832deb875019cdb0b5eee7b7..0000000000000000000000000000000000000000 --- a/sort_cluster.py +++ /dev/null @@ -1,96 +0,0 @@ -import utils as u -from tqdm import tqdm -import torch -import matplotlib.pyplot as plt -import soundfile as sf -import models -import os -import numpy as np -import pandas as pd -import hdbscan -import argparse -try: - import sounddevice as sd - soundAvailable = True -except: - soundAvailable = False - -parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description='blabla') -parser.add_argument('encodings', type=str) -parser.add_argument('--min_cluster_size', type=int, default=50) -parser.add_argument('--min_sample', type=int, default=10) -parser.add_argument('--eps', type=float, default=1) -args = parser.parse_args() - -gpu = torch.device('cuda:0') -frontend = models.get['frontend_logMel'].to(gpu) - -a = np.load(args.encodings, allow_pickle=True).item() -df = pd.read_pickle('detections.pkl') -idxs, umap = a['idx'], a['umap'] - -# cluster the embedings (min_cluster_size and min_samples parameters need to be tuned) -df.loc[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size, - min_samples=args.min_sample, - core_dist_n_jobs=-1, -# cluster_selection_epsilon=args.eps, - cluster_selection_method='leaf').fit_predict(umap) -df.loc[idxs, ['umap_x', 'umap_y']] = umap -df.cluster = df.cluster.astype(int) - -figscat = plt.figure(figsize=(20, 10)) -plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}') -for c, grp in df.groupby('cluster'): - plt.scatter(grp.umap_x, grp.umap_y, s=3, alpha=.1, c='grey' if c == -1 else None) -#plt.scatter(df[((df.type!='Vocalization')&(~df.type.isna()))].umap_x, df[((df.type!='Vocalization')&(~df.type.isna()))].umap_y, marker='x') -plt.tight_layout() -axScat = figscat.axes[0] -plt.savefig('projection') -figSpec = plt.figure() -plt.scatter(0, 0) -axSpec = figSpec.axes[0] - -print(df.groupby('cluster').count()) - -class temp(): - def __init__(self): - self.row = "" - def onclick(self, event): - #get row - left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1] - rangex, rangey = right - left, top - bottom - closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin() - sig, fs = sf.read(f'/data_ssd/marmossets/{closest}.wav') - spec = frontend(torch.Tensor(sig).to(gpu).view(1, -1).float()).detach().cpu().squeeze() - axSpec.imshow(spec, origin='lower', aspect='auto') - row = df.loc[closest] - axSpec.set_title(f'{closest}, cluster {row.cluster} ({(df.cluster==row.cluster).sum()} points)') - axScat.scatter(row.umap_x, row.umap_y, c='r') - axScat.set_xlim(left, right) - axScat.set_ylim(bottom, top) - figSpec.canvas.draw() - figscat.canvas.draw() - if soundAvailable: - sd.play(sig*10, fs) - -mtemp = temp() - -cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick) - -plt.show() - -if input('print cluster pngs ??') != 'y': - exit() - -os.system('rm -R cluster_pngs/*') -for c, grp in df.groupby('cluster'): - if c == -1 or len(grp) > 10_000: - continue - os.system('mkdir cluster_pngs/'+str(c)) - with 
diff --git a/sort_cluster.py b/sort_cluster.py
deleted file mode 100755
index ac8d92e854dff183832deb875019cdb0b5eee7b7..0000000000000000000000000000000000000000
--- a/sort_cluster.py
+++ /dev/null
@@ -1,96 +0,0 @@
-import utils as u
-from tqdm import tqdm
-import torch
-import matplotlib.pyplot as plt
-import soundfile as sf
-import models
-import os
-import numpy as np
-import pandas as pd
-import hdbscan
-import argparse
-try:
-    import sounddevice as sd
-    soundAvailable = True
-except:
-    soundAvailable = False
-
-parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description='blabla')
-parser.add_argument('encodings', type=str)
-parser.add_argument('--min_cluster_size', type=int, default=50)
-parser.add_argument('--min_sample', type=int, default=10)
-parser.add_argument('--eps', type=float, default=1)
-args = parser.parse_args()
-
-gpu = torch.device('cuda:0')
-frontend = models.get['frontend_logMel'].to(gpu)
-
-a = np.load(args.encodings, allow_pickle=True).item()
-df = pd.read_pickle('detections.pkl')
-idxs, umap = a['idx'], a['umap']
-
-# cluster the embedings (min_cluster_size and min_samples parameters need to be tuned)
-df.loc[idxs, 'cluster'] = hdbscan.HDBSCAN(min_cluster_size=args.min_cluster_size,
-                                          min_samples=args.min_sample,
-                                          core_dist_n_jobs=-1,
-#                                          cluster_selection_epsilon=args.eps,
-                                          cluster_selection_method='leaf').fit_predict(umap)
-df.loc[idxs, ['umap_x', 'umap_y']] = umap
-df.cluster = df.cluster.astype(int)
-
-figscat = plt.figure(figsize=(20, 10))
-plt.title(f'{args.encodings} {args.min_cluster_size} {args.min_sample} {args.eps}')
-for c, grp in df.groupby('cluster'):
-    plt.scatter(grp.umap_x, grp.umap_y, s=3, alpha=.1, c='grey' if c == -1 else None)
-#plt.scatter(df[((df.type!='Vocalization')&(~df.type.isna()))].umap_x, df[((df.type!='Vocalization')&(~df.type.isna()))].umap_y, marker='x')
-plt.tight_layout()
-axScat = figscat.axes[0]
-plt.savefig('projection')
-figSpec = plt.figure()
-plt.scatter(0, 0)
-axSpec = figSpec.axes[0]
-
-print(df.groupby('cluster').count())
-
-class temp():
-    def __init__(self):
-        self.row = ""
-    def onclick(self, event):
-        #get row
-        left, right, bottom, top = axScat.get_xlim()[0], axScat.get_xlim()[1], axScat.get_ylim()[0], axScat.get_ylim()[1]
-        rangex, rangey = right - left, top - bottom
-        closest = (np.sqrt(((df.umap_x - event.xdata)/rangex)**2 + ((df.umap_y - event.ydata)/rangey)**2)).idxmin()
-        sig, fs = sf.read(f'/data_ssd/marmossets/{closest}.wav')
-        spec = frontend(torch.Tensor(sig).to(gpu).view(1, -1).float()).detach().cpu().squeeze()
-        axSpec.imshow(spec, origin='lower', aspect='auto')
-        row = df.loc[closest]
-        axSpec.set_title(f'{closest}, cluster {row.cluster} ({(df.cluster==row.cluster).sum()} points)')
-        axScat.scatter(row.umap_x, row.umap_y, c='r')
-        axScat.set_xlim(left, right)
-        axScat.set_ylim(bottom, top)
-        figSpec.canvas.draw()
-        figscat.canvas.draw()
-        if soundAvailable:
-            sd.play(sig*10, fs)
-
-mtemp = temp()
-
-cid = figscat.canvas.mpl_connect('button_press_event', mtemp.onclick)
-
-plt.show()
-
-if input('print cluster pngs ??') != 'y':
-    exit()
-
-os.system('rm -R cluster_pngs/*')
-for c, grp in df.groupby('cluster'):
-    if c == -1 or len(grp) > 10_000:
-        continue
-    os.system('mkdir cluster_pngs/'+str(c))
-    with torch.no_grad():
-        for x, idx in tqdm(torch.utils.data.DataLoader(u.Dataset(grp.sample(200), sampleDur=.5), batch_size=1, num_workers=8), leave=False, desc=str(c)):
-            x = x.to(gpu)
-            x = frontend(x).cpu().detach().squeeze()
-            plt.imshow(x, origin='lower', aspect='auto')
-            plt.savefig(f'cluster_pngs/{c}/{idx.squeeze().item()}')
-            plt.close()