From 60d4559283f8ae7bafcf7ecffa1a048593044453 Mon Sep 17 00:00:00 2001
From: "paul.best" <paul.best@lis-lab.fr>
Date: Tue, 27 Aug 2024 17:21:07 +0200
Subject: [PATCH] Add standalone predict.py script, Viterbi-decoded tcrepe_ftspV variant, and improved SNR estimation

---
 .gitignore               |  0
 eval_all.py              |  4 ++--
 get_SNR.py               | 23 ++++++++++++++++----
 metadata.py              |  8 +++----
 predict.py               | 47 ++++++++++++++++++++++++++++++++++++++++
 predict_requirements.txt |  7 ++++++
 run_all.py               | 15 ++++++++++---
 train_crepe.py           |  2 --
 8 files changed, 91 insertions(+), 15 deletions(-)
 mode change 100644 => 100755 .gitignore
 create mode 100644 predict.py
 create mode 100644 predict_requirements.txt
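
Usage sketch for the new standalone predict.py (the folder and file names, the
--compress value and the 0.5 confidence threshold are illustrative assumptions;
the finetuned weights are expected at the script's default --model_path):

    # python predict.py recordings/ --compress 2
    # --compress 2 halves all frequencies before analysis so vocalisations above
    # CREPE's 2 kHz ceiling fall back into range; f0 is rescaled to Hz on output.
    # Each input file gets a sibling <name>_f0.csv written next to it.
    import pandas as pd

    df = pd.read_csv('recordings/example_f0.csv')   # columns: time, f0, confidence
    voiced = df[df.confidence > 0.5]                 # arbitrary example threshold
    print(voiced.f0.describe())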

diff --git a/.gitignore b/.gitignore
old mode 100644
new mode 100755
diff --git a/eval_all.py b/eval_all.py
index fadf8d1..94e24fd 100644
--- a/eval_all.py
+++ b/eval_all.py
@@ -15,7 +15,7 @@ parser.add_argument('specie', type=str, help="Species to treat specifically", de
 args = parser.parse_args()
 
 for specie in species if args.specie=='all' else args.specie.split(' '):
-    algos = {'pyin', 'praat', 'crepe', 'tcrepe', 'tcrepe_ftsp', 'tcrepe_ftoth', 'basic', 'pesto', 'pesto_ft', 'pesto_ftoth'}
+    algos = {'pyin', 'praat', 'crepe', 'tcrepe', 'tcrepe_ftsp', 'tcrepe_ftspV', 'tcrepe_ftoth', 'basic', 'pesto', 'pesto_ft', 'pesto_ftoth'}
     # Get optimal thresholds
     confs = {k:[] for k in algos}
     confs['label'] = []
@@ -49,7 +49,7 @@ for specie in species if args.specie=='all' else args.specie.split(' '):
             out.loc[algo, 'Pitch acc'] = mir_eval.melody.raw_pitch_accuracy(df.annot>0, df.annot, df[algo+'_conf'], df[algo+'_f0'], cent_tolerance=cent_thr)
             out.loc[algo, 'Chroma acc'] = mir_eval.melody.raw_chroma_accuracy(df.annot>0, df.annot, df[algo+'_conf'], df[algo+'_f0'], cent_tolerance=cent_thr)
             out.at[algo, 'diff_distrib'] = list(abs(df[algo+'_f0'] - df.annot))
-            out.loc[algo, 'Voc. recall'] = ((df.annot > 0 ) & ( df[algo+'_conf'] > thrs[algo])).sum() > 0.5 * (df.annot > 0).sum()
+            out.loc[algo, 'Voc. recall'] = ((df.annot > 0 ) & ( df[algo+'_conf'] > thrs[algo])).sum() > 0.33 * (df.annot > 0).sum()
         return out
 
     df = pd.concat(p_umap(fun, glob(species[specie]['wavpath'][:-4]+'_preds.csv'), desc=f'{specie} get perf'))
diff --git a/get_SNR.py b/get_SNR.py
index 2f0c573..af82951 100644
--- a/get_SNR.py
+++ b/get_SNR.py
@@ -1,18 +1,33 @@
 import pandas as pd, numpy as np
+from scipy import signal
 from p_tqdm import p_umap
 from metadata import species
 from glob import glob
 import librosa
 
-
 def fun(fn):
-    sig, fs = librosa.load(fn)
     df = pd.read_csv(fn[:-4]+'_preds.csv')
+    sig, fs = librosa.load(fn)
+
+    if df[df.annot > 0].annot.min() >= fs / 2: # lowest annotated F0 is above the Nyquist frequency: skip the file
+        return fn, None
+
+    sos = signal.butter(3, df[df.annot>0].annot.min() * 2 / fs, 'highpass', output='sos') # high-pass at the lowest annotated F0 (cutoff normalised by the Nyquist frequency)
+    sig = signal.sosfiltfilt(sos, sig)
+
     start = df[df.annot > 0].time.min()
     end = df[df.annot > 0].time.max()
     S = np.std(sig[int(start*fs):int(end*fs)])
-    N = np.std(np.concatenate([sig[:int(start*fs)], sig[int(end*fs):]]))
-    return fn, 10 * np.log10(S/N)
+
+    if end - start < len(sig)/fs/2: # if the voiced section is shorter than half the signal, estimate the noise over an equal duration around it (half before, half after)
+        N = np.std(np.concatenate([ sig[ int((start-(end-start)/2)*fs) : int(start*fs) ], sig[ int(end*fs) : int((end + (end - start)/2)*fs)] ]))
+    else:
+        N = np.std(np.concatenate([sig[:int(start*fs)], sig[int(end*fs):]]))
+
+    if S < N: # noise estimate exceeds the voiced-segment estimate: SNR undefined
+        return fn, None
+    else:
+        return fn, 10 * np.log10(S/N - 1)
 
 for specie in species:
     ret = p_umap(fun, glob(species[specie]['wavpath']), desc=specie)
diff --git a/metadata.py b/metadata.py
index dfef6f0..ae26259 100644
--- a/metadata.py
+++ b/metadata.py
@@ -1,5 +1,5 @@
 species = {
-    'wolves':{
+    'canids':{
         'wavpath': 'data/wolves/*/*.wav',
         'FS': 16000,
         'nfft': 1024,
@@ -48,7 +48,7 @@ species = {
         'downsample':3,
         'step': 1/16
     },
-    'la_Palma_chaffinches':{
+    'La_Palma_chaffinches':{
         'wavpath': 'data/FCPalmae/cut/*.wav',
         'FS':44100,
         'nfft':1024,
@@ -62,14 +62,14 @@ species = {
         'downsample':1,
         'step': 1/8
     },
-    'Reunion_white_eyes':{
+    'Reunion_grey_white_eyes':{
         'wavpath': 'data/white_eye/cut/*.wav',
         'FS':44100,
         'nfft':1024,
         'downsample':5,
         'step': 1/16
     },
-    'long_billed_hermits':{
+    'long-billed_hermits':{
         'wavpath':'data/marcelo/long_billed_hermit_songs/*.wav',
         'FS':44100,
         'nfft':512,
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..6d3207c
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,47 @@
+import matplotlib.pyplot as plt
+import argparse, os, tqdm
+import torchcrepe, torch, librosa
+import pandas as pd, numpy as np
+
+parser = argparse.ArgumentParser()
+parser.add_argument('indir', type=str, help="Directory with sound files to process")
+parser.add_argument('--model_path', type=str, help="Path of model weights", default='crepe_ft/model_all.pth')
+parser.add_argument('--compress', type=float, help="Compression factor used to shift frequencies into CREPE's range [32Hz; 2kHz]. \
+    Frequencies are divided by the given factor by artificially changing the sampling rate (slowing down / speeding up the signal).", default=1)
+parser.add_argument('--step', type=float, help="Step used between each prediction (in seconds)", default=256 / torchcrepe.SAMPLE_RATE)
+parser.add_argument('--decoder', choices=['argmax', 'weighted_argmax', 'viterbi'], help="Decoder used to postprocess predictions", default='viterbi')
+parser.add_argument('--print', action='store_true', help="Print spectrograms with overlaid F0 predictions to assess their quality")
+args = parser.parse_args()
+
+# Initialisations
+device, batch_size = ('cuda', 64) if torch.cuda.is_available() else ('cpu', 1)
+model = torchcrepe.Crepe('full').eval().to(device)
+model.load_state_dict(torch.load(args.model_path, map_location=device))
+decoder = torchcrepe.decode.__dict__[args.decoder]
+
+for filename in tqdm.tqdm(os.listdir(args.indir)):
+    try:
+        sig, fs = librosa.load(os.path.join(args.indir, filename), sr=int(torchcrepe.SAMPLE_RATE * args.compress))
+    except Exception as e:
+        print(f'Failed to load {filename}: {e}')
+        continue
+    
+    generator = torchcrepe.core.preprocess(torch.tensor(sig).unsqueeze(0), torchcrepe.SAMPLE_RATE, \
+        hop_length=int(args.step / args.compress * torchcrepe.SAMPLE_RATE), batch_size=batch_size, device=device)
+    with torch.inference_mode():
+        preds = torch.vstack([model(frames).cpu() for frames in generator]).T.unsqueeze(0)
+        f0 = (torchcrepe.core.postprocess(preds, decoder=decoder) * args.compress).squeeze()
+    confidence = preds.max(axis=1)[0].squeeze()
+    time = np.arange(0, len(sig)/fs, args.step)
+    
+    df = pd.DataFrame({'time':time, 'f0':f0, 'confidence':confidence})
+    df.to_csv(os.path.join(args.indir, filename.rsplit('.',1)[0]+'_f0.csv'), index=False)
+    # Plot F0 predictions over spectrograms
+    if args.print:
+        plt.figure(figsize=(max(6.4, 6.4*time[-1]/2), 4.8))
+        plt.specgram(sig, Fs=fs)
+        plt.scatter(time, f0, c=confidence)
+        plt.colorbar()
+        plt.tight_layout()
+        plt.savefig(os.path.join(args.indir, filename.rsplit('.',1)[0])+'_f0.png')
+        plt.close()
\ No newline at end of file
diff --git a/predict_requirements.txt b/predict_requirements.txt
new file mode 100644
index 0000000..72ae954
--- /dev/null
+++ b/predict_requirements.txt
@@ -0,0 +1,7 @@
+librosa==0.10.1
+matplotlib==3.9.2
+numpy==1.23.5
+pandas==1.5.2
+torch==1.13.1+cu117
+torchcrepe==0.0.22
+tqdm==4.64.1
diff --git a/run_all.py b/run_all.py
index 4573c66..260ce8f 100644
--- a/run_all.py
+++ b/run_all.py
@@ -15,12 +15,15 @@ tcrepe_model.load_state_dict(torch.load('/home/paul.best/.local/lib/python3.9/si
 
 cents_mapping = np.linspace(0, 7180, 360) + 1997.3794084376191
 
-def run_tcrepe(model, sig, fs, dt):
+def run_tcrepe(model, sig, fs, dt, viterbi=False):
     generator = torchcrepe.core.preprocess(torch.tensor(sig).unsqueeze(0), fs, hop_length=dt*fs if fs != torchcrepe.SAMPLE_RATE else int(dt*fs),\
                                            batch_size=batch_size, device=device, pad=False)
     with torch.no_grad():
         preds = np.vstack([model(frames).cpu().numpy() for frames in generator])
-    f0 = 10 * 2 ** (crepe.core.to_local_average_cents(preds) / 1200)
+    if viterbi:
+        f0 = 10 * 2 ** (crepe.core.to_viterbi_cents(preds) / 1200)
+    else:
+        f0 = 10 * 2 ** (crepe.core.to_local_average_cents(preds) / 1200)
     confidence = np.max(preds, axis=1)
     time = np.arange(torchcrepe.WINDOW_SIZE/2, len(sig)/fs*torchcrepe.SAMPLE_RATE - torchcrepe.WINDOW_SIZE/2 + 1e-9, dt*torchcrepe.SAMPLE_RATE) / torchcrepe.SAMPLE_RATE
     return time, f0, confidence
@@ -34,7 +37,7 @@ parser.add_argument('--split', type=int, help="Section to test on between 0 and
 args = parser.parse_args()
 
 algos = ['praat_f0','pyin_f0','crepe_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0', 'pesto_ftoth_f0']
-quick_algos = ['praat_f0','pyin_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
+quick_algos = ['praat_f0','pyin_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftspV_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
 
 if args.overwrite:
     print('Overwriting previous results')
@@ -107,6 +110,12 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out.loc[mask, 'tcrepe_ftsp_f0'], out.loc[mask, 'tcrepe_ftsp_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
             out.tcrepe_ftsp_f0 *= downsample
 
+        if not 'tcrepe_ftspV_f0' in out.columns and tcrepe_ftsp_model: # torch crepe finetuned on the target species, with Viterbi decoding
+            time, f0, confidence = run_tcrepe(tcrepe_ftsp_model, sig, fs, dt, viterbi=True)
+            mask = ((out.time > time[0])&(out.time < time[-1]))
+            out.loc[mask, 'tcrepe_ftspV_f0'], out.loc[mask, 'tcrepe_ftspV_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
+            out.tcrepe_ftspV_f0 *= downsample
+
         if not 'tcrepe_ftoth_f0' in out.columns and tcrepe_ftoth_model: # torch crepe finetuned on other species than the target
             time, f0, confidence = run_tcrepe(tcrepe_ftoth_model, sig, fs, dt)
             mask = ((out.time > time[0])&(out.time < time[-1]))
diff --git a/train_crepe.py b/train_crepe.py
index 8af7ee8..d1e1c64 100644
--- a/train_crepe.py
+++ b/train_crepe.py
@@ -30,8 +30,6 @@ if not os.path.isfile(f'crepe_ft/train_set_{suffix}.pkl'):
         if args.only:
             files = files[:int(len(files)/5*args.split)] + files[int(len(files)/5*(args.split+1)):]
         for fn in tqdm.tqdm(pd.Series(files).sample(min(len(files), 1000)), desc='Preparing dataset for '+specie):
-            if os.path.isfile(f'noisy_pngs/{fn[:-4]}.png'):
-                continue
             annot = pd.read_csv(fn[:-4]+'.csv').drop_duplicates(subset='Time')
             sig, fs = librosa.load(fn, sr=None)
             sig = resampy.resample(sig, fs//downsample, FS)
-- 
GitLab