From 60d4559283f8ae7bafcf7ecffa1a048593044453 Mon Sep 17 00:00:00 2001
From: "paul.best" <paul.best@lis-lab.fr>
Date: Tue, 27 Aug 2024 17:21:07 +0200
Subject: [PATCH] last script update

---
 .gitignore               |  0
 eval_all.py              |  4 ++--
 get_SNR.py               | 23 ++++++++++++++++----
 metadata.py              |  8 +++----
 predict.py               | 47 ++++++++++++++++++++++++++++++++++++++++
 predict_requirements.txt |  7 ++++++
 run_all.py               | 15 ++++++++++---
 train_crepe.py           |  2 --
 8 files changed, 91 insertions(+), 15 deletions(-)
 mode change 100644 => 100755 .gitignore
 create mode 100644 predict.py
 create mode 100644 predict_requirements.txt

diff --git a/.gitignore b/.gitignore
old mode 100644
new mode 100755
diff --git a/eval_all.py b/eval_all.py
index fadf8d1..94e24fd 100644
--- a/eval_all.py
+++ b/eval_all.py
@@ -15,7 +15,7 @@ parser.add_argument('specie', type=str, help="Species to treat specifically", de
 args = parser.parse_args()
 
 for specie in species if args.specie=='all' else args.specie.split(' '):
-    algos = {'pyin', 'praat', 'crepe', 'tcrepe', 'tcrepe_ftsp', 'tcrepe_ftoth', 'basic', 'pesto', 'pesto_ft', 'pesto_ftoth'}
+    algos = {'pyin', 'praat', 'crepe', 'tcrepe', 'tcrepe_ftsp', 'tcrepe_ftspV', 'tcrepe_ftoth', 'basic', 'pesto', 'pesto_ft', 'pesto_ftoth'}
     # Get optimal thresholds
     confs = {k:[] for k in algos}
     confs['label'] = []
@@ -49,7 +49,7 @@ for specie in species if args.specie=='all' else args.specie.split(' '):
             out.loc[algo, 'Pitch acc'] = mir_eval.melody.raw_pitch_accuracy(df.annot>0, df.annot, df[algo+'_conf'], df[algo+'_f0'], cent_tolerance=cent_thr)
             out.loc[algo, 'Chroma acc'] = mir_eval.melody.raw_chroma_accuracy(df.annot>0, df.annot, df[algo+'_conf'], df[algo+'_f0'], cent_tolerance=cent_thr)
             out.at[algo, 'diff_distrib'] = list(abs(df[algo+'_f0'] - df.annot))
-            out.loc[algo, 'Voc. recall'] = ((df.annot > 0 ) & ( df[algo+'_conf'] > thrs[algo])).sum() > 0.5 * (df.annot > 0).sum()
+            out.loc[algo, 'Voc. recall'] = ((df.annot > 0 ) & ( df[algo+'_conf'] > thrs[algo])).sum() > 0.33 * (df.annot > 0).sum()
         return out
 
     df = pd.concat(p_umap(fun, glob(species[specie]['wavpath'][:-4]+'_preds.csv'), desc=f'{specie} get perf'))
diff --git a/get_SNR.py b/get_SNR.py
index 2f0c573..af82951 100644
--- a/get_SNR.py
+++ b/get_SNR.py
@@ -1,18 +1,33 @@
 import pandas as pd, numpy as np
+from scipy import signal
 from p_tqdm import p_umap
 from metadata import species
 from glob import glob
 import librosa
-
 def fun(fn):
-    sig, fs = librosa.load(fn)
     df = pd.read_csv(fn[:-4]+'_preds.csv')
+    sig, fs = librosa.load(fn)
+
+    if df[df.annot > 0].annot.min() >= fs / 2:
+        return fn, None
+
+    sos = signal.butter(3, df[df.annot>0].annot.min() * 2 / fs, 'highpass', output='sos')
+    sig = signal.sosfiltfilt(sos, sig)
+
     start = df[df.annot > 0].time.min()
     end = df[df.annot > 0].time.max()
     S = np.std(sig[int(start*fs):int(end*fs)])
-    N = np.std(np.concatenate([sig[:int(start*fs)], sig[int(end*fs):]]))
-    return fn, 10 * np.log10(S/N)
+
+    if end - start < len(sig)/fs/2: # if the voiced section is smaller than half the signal duration, we estimate the noise over the same duration only
+        N = np.std(np.concatenate([ sig[ int((start-(end-start)/2)*fs) : int(start*fs) ], sig[ int(end*fs) : int((end + (end - start)/2)*fs)] ]))
+    else:
+        N = np.std(np.concatenate([sig[:int(start*fs)], sig[int(end*fs):]]))
+
+    if S < N:
+        return fn, None
+    else:
+        return fn, 10 * np.log10(S/N -1)
 
 for specie in species:
     ret = p_umap(fun, glob(species[specie]['wavpath']), desc=specie)
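[Note on the reworked SNR estimate above: the recording is first high-passed just below the lowest annotated f0, the noise level N is measured on a window matched in duration to the voiced section when possible, and files where S < N are skipped. One reading of the new -1 term (an interpretation, not stated in the code) is that S, measured over the voiced window, still contains noise, so S/N - 1 approximates the ratio of signal alone to noise. A minimal numeric sketch:

    import numpy as np

    # Illustrative values only: standard deviation inside vs. outside the voiced window
    S, N = 0.02, 0.01                   # the voiced window still contains noise
    snr_db = 10 * np.log10(S / N - 1)   # the -1 removes the noise share, as in get_SNR.py
    print(f'{snr_db:.1f} dB')           # 0.0 dB: estimated signal level equals the noise level
]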
diff --git a/metadata.py b/metadata.py
index dfef6f0..ae26259 100644
--- a/metadata.py
+++ b/metadata.py
@@ -1,5 +1,5 @@
 species = {
-    'wolves':{
+    'canids':{
         'wavpath': 'data/wolves/*/*.wav',
         'FS': 16000,
         'nfft': 1024,
@@ -48,7 +48,7 @@ species = {
         'downsample':3,
         'step': 1/16
     },
-    'la_Palma_chaffinches':{
+    'La_Palma_chaffinches':{
         'wavpath': 'data/FCPalmae/cut/*.wav',
         'FS':44100,
         'nfft':1024,
@@ -62,14 +62,14 @@ species = {
         'downsample':1,
         'step': 1/8
     },
-    'Reunion_white_eyes':{
+    'Reunion_grey_white_eyes':{
         'wavpath': 'data/white_eye/cut/*.wav',
         'FS':44100,
         'nfft':1024,
         'downsample':5,
         'step': 1/16
     },
-    'long_billed_hermits':{
+    'long-billed_hermits':{
         'wavpath':'data/marcelo/long_billed_hermit_songs/*.wav',
         'FS':44100,
         'nfft':512,
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..6d3207c
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,47 @@
+import matplotlib.pyplot as plt
+import argparse, os, tqdm
+import torchcrepe, torch, librosa
+import pandas as pd, numpy as np
+
+parser = argparse.ArgumentParser()
+parser.add_argument('indir', type=str, help="Directory with sound files to process")
+parser.add_argument('--model_path', type=str, help="Path of model weights", default='crepe_ft/model_all.pth')
+parser.add_argument('--compress', type=float, help="Compression factor used to shift frequencies into CREPE's range [32Hz; 2kHz]. \
+    Frequencies are divided by the given factor by artificially changing the sampling rate (slowing down / speeding up the signal).", default=1)
+parser.add_argument('--step', type=float, help="Step used between each prediction (in seconds)", default=256 / torchcrepe.SAMPLE_RATE)
+parser.add_argument('--decoder', choices=['argmax', 'weighted_argmax', 'viterbi'], help="Decoder used to postprocess predictions", default='viterbi')
+parser.add_argument('--print', action='store_true', help="Print spectrograms with overlaid F0 predictions to assess their quality")
+args = parser.parse_args()
+
+# Initialisations
+device, batch_size = ('cuda', 64) if torch.cuda.is_available() else ('cpu', 1)
+model = torchcrepe.Crepe('full').eval().to(device)
+model.load_state_dict(torch.load(args.model_path, map_location=device))
+decoder = torchcrepe.decode.__dict__[args.decoder]
+
+for filename in tqdm.tqdm(os.listdir(args.indir)):
+    try:
+        sig, fs = librosa.load(os.path.join(args.indir, filename), sr=int(torchcrepe.SAMPLE_RATE * args.compress))
+    except Exception:
+        print(f'Failed to load {filename}')
+        continue
+
+    generator = torchcrepe.core.preprocess(torch.tensor(sig).unsqueeze(0), torchcrepe.SAMPLE_RATE, \
+        hop_length=int(args.step / args.compress * torchcrepe.SAMPLE_RATE), batch_size=batch_size, device=device)
+    with torch.inference_mode():
+        preds = torch.vstack([model(frames).cpu() for frames in generator]).T.unsqueeze(0)
+    f0 = (torchcrepe.core.postprocess(preds, decoder=decoder) * args.compress).squeeze()
+    confidence = preds.max(axis=1)[0].squeeze()
+    time = np.arange(0, len(sig)/fs, args.step)
+
+    df = pd.DataFrame({'time':time, 'f0':f0, 'confidence':confidence})
+    df.to_csv(os.path.join(args.indir, filename.rsplit('.',1)[0]+'_f0.csv'), index=False)
+    # Plot F0 predictions over spectrograms
+    if args.print:
+        plt.figure(figsize=(max(6.4, 6.4*time[-1]/2), 4.8))
+        plt.specgram(sig, Fs=fs)
+        plt.scatter(time, f0, c=confidence)
+        plt.colorbar()
+        plt.tight_layout()
+        plt.savefig(os.path.join(args.indir, filename.rsplit('.',1)[0])+'_f0.png')
+        plt.close()
\ No newline at end of file
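[Note on the --compress flag introduced above: loading a file at args.compress times CREPE's native 16 kHz while letting the model treat the samples as 16 kHz audio divides every frequency by the compression factor, and postprocess multiplies the predictions back. A minimal sketch of the arithmetic (the 6 kHz species is a hypothetical example):

    import torchcrepe

    compress = 4                                       # hypothetical species singing around 6 kHz, above CREPE's ~2 kHz ceiling
    load_sr = int(torchcrepe.SAMPLE_RATE * compress)   # librosa.load at 64000 Hz...
    apparent_f0 = 6000 / compress                      # ...so the model "hears" a 1500 Hz tone, inside its range
    true_f0 = apparent_f0 * compress                   # predict.py scales predictions back by args.compress
    print(load_sr, apparent_f0, true_f0)               # 64000 1500.0 6000.0
]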
diff --git a/predict_requirements.txt b/predict_requirements.txt
new file mode 100644
index 0000000..72ae954
--- /dev/null
+++ b/predict_requirements.txt
@@ -0,0 +1,7 @@
+librosa==0.10.1
+matplotlib==3.9.2
+numpy==1.23.5
+pandas==1.5.2
+torch==1.13.1+cu117
+torchcrepe==0.0.22
+tqdm==4.64.1
diff --git a/run_all.py b/run_all.py
index 4573c66..260ce8f 100644
--- a/run_all.py
+++ b/run_all.py
@@ -15,12 +15,15 @@ tcrepe_model.load_state_dict(torch.load('/home/paul.best/.local/lib/python3.9/si
 cents_mapping = np.linspace(0, 7180, 360) + 1997.3794084376191
 
-def run_tcrepe(model, sig, fs, dt):
+def run_tcrepe(model, sig, fs, dt, viterbi=False):
     generator = torchcrepe.core.preprocess(torch.tensor(sig).unsqueeze(0), fs, hop_length=dt*fs if fs != torchcrepe.SAMPLE_RATE else int(dt*fs),\
         batch_size=batch_size, device=device, pad=False)
     with torch.no_grad():
         preds = np.vstack([model(frames).cpu().numpy() for frames in generator])
-    f0 = 10 * 2 ** (crepe.core.to_local_average_cents(preds) / 1200)
+    if viterbi:
+        f0 = 10 * 2 ** (crepe.core.to_viterbi_cents(preds) / 1200)
+    else:
+        f0 = 10 * 2 ** (crepe.core.to_local_average_cents(preds) / 1200)
     confidence = np.max(preds, axis=1)
     time = np.arange(torchcrepe.WINDOW_SIZE/2, len(sig)/fs*torchcrepe.SAMPLE_RATE - torchcrepe.WINDOW_SIZE/2 + 1e-9, dt*torchcrepe.SAMPLE_RATE) / torchcrepe.SAMPLE_RATE
     return time, f0, confidence
@@ -34,7 +37,7 @@ parser.add_argument('--split', type=int, help="Section to test on between 0 and
 args = parser.parse_args()
 
 algos = ['praat_f0','pyin_f0','crepe_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0', 'pesto_ftoth_f0']
-quick_algos = ['praat_f0','pyin_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
+quick_algos = ['praat_f0','pyin_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftspV_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
 
 if args.overwrite:
     print('Overwriting previous results')
@@ -107,6 +110,12 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out.loc[mask, 'tcrepe_ftsp_f0'], out.loc[mask, 'tcrepe_ftsp_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
             out.tcrepe_ftsp_f0 *= downsample
 
+        if not 'tcrepe_ftspV_f0' in out.columns and tcrepe_ftsp_model: # torch crepe finetuned on the target species, decoded with Viterbi
+            time, f0, confidence = run_tcrepe(tcrepe_ftsp_model, sig, fs, dt, viterbi=True)
+            mask = ((out.time > time[0])&(out.time < time[-1]))
+            out.loc[mask, 'tcrepe_ftspV_f0'], out.loc[mask, 'tcrepe_ftspV_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
+            out.tcrepe_ftspV_f0 *= downsample
+
         if not 'tcrepe_ftoth_f0' in out.columns and tcrepe_ftoth_model: # torch crepe finetuned on other species than the target
             time, f0, confidence = run_tcrepe(tcrepe_ftoth_model, sig, fs, dt)
             mask = ((out.time > time[0])&(out.time < time[-1]))
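[Note on the new tcrepe_ftspV variant above: it reuses the species-finetuned model and only swaps the decoder. CREPE's Viterbi decoding searches a smooth path through the activation matrix instead of decoding each frame independently, which typically suppresses octave jumps at some computational cost. A minimal sketch of the two decoders side by side (the random matrix is a stand-in for real model activations):

    import numpy as np, crepe

    preds = np.random.rand(100, 360).astype(np.float32)  # stand-in for (frames x 360 pitch bins) model output
    f0_avg = 10 * 2 ** (crepe.core.to_local_average_cents(preds) / 1200)  # frame-wise weighted average
    f0_vit = 10 * 2 ** (crepe.core.to_viterbi_cents(preds) / 1200)        # smooth path, penalises large pitch jumps
]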
diff --git a/train_crepe.py b/train_crepe.py
index 8af7ee8..d1e1c64 100644
--- a/train_crepe.py
+++ b/train_crepe.py
@@ -30,8 +30,6 @@ if not os.path.isfile(f'crepe_ft/train_set_{suffix}.pkl'):
         if args.only:
             files = files[:int(len(files)/5*args.split)] + files[int(len(files)/5*(args.split+1)):]
         for fn in tqdm.tqdm(pd.Series(files).sample(min(len(files), 1000)), desc='Preparing dataset for '+specie):
-            if os.path.isfile(f'noisy_pngs/{fn[:-4]}.png'):
-                continue
             annot = pd.read_csv(fn[:-4]+'.csv').drop_duplicates(subset='Time')
             sig, fs = librosa.load(fn, sr=None)
             sig = resampy.resample(sig, fs//downsample, FS)
-- 
GitLab
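[For reference, a minimal sketch of consuming the files the new predict.py writes (the directory and file names are hypothetical; the columns are those built by the script's DataFrame):

    import pandas as pd

    # e.g. after: python predict.py recordings/ --compress 2
    df = pd.read_csv('recordings/call_01_f0.csv')  # columns: time, f0, confidence
    voiced = df[df.confidence > 0.5]               # 0.5 is an arbitrary example threshold
    print(voiced.f0.median())
]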