Skip to content
Snippets Groups Projects
Commit a9398aac authored by Paul Best
Browse files

add pesto_ft run

parent 285f766a
No related branches found
No related tags found
No related merge requests found
import pandas as pd, numpy as np import pandas as pd, numpy as np
from sklearn import metrics as skmetrics from sklearn import metrics as skmetrics
import mir_eval.melody import mir_eval.melody
import os, metadata import os, argparse
from metadata import species
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from glob import glob from glob import glob
from p_tqdm import p_umap from p_tqdm import p_umap
...@@ -9,16 +10,19 @@ from tqdm import tqdm ...@@ -9,16 +10,19 @@ from tqdm import tqdm
cent_thr = 50 cent_thr = 50
metrics = ['recall', 'FA', 'pitch_acc', 'chroma_acc', 'diff_distrib'] metrics = ['recall', 'FA', 'pitch_acc', 'chroma_acc', 'diff_distrib']
drop_noisy_vocs = False parser = argparse.ArgumentParser()
drop_noisy_bins = False parser.add_argument('specie', type=str, help="Species to treat specifically", default=None)
def _str2bool(value):
    """Convert a command-line token to a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including "False" and "0") becomes True. This
    converter maps the usual spellings explicitly and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {'1', 'true', 't', 'yes', 'y', 'on'}:
        return True
    if token in {'0', 'false', 'f', 'no', 'n', 'off', ''}:
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# NOTE(review): the positional 'specie' argument declared just above passes
# default=None without nargs='?', so argparse still requires it on the
# command line — confirm whether it was meant to be optional.

# Both flags accept an explicit value (`--flag true` / `--flag false`) or can
# be given bare (`--flag`), which means True through nargs='?' + const=True.
# Help texts describe what each flag gates further down in the script:
# drop_noisy_bins masks annotation bins by salience / SHR, while
# drop_noisy_vocs skips whole files that have a matching noisy_pngs/ image.
parser.add_argument('--drop_noisy_bins', type=_str2bool, nargs='?', const=True, help="drop noisy STFT bins", default=False)
parser.add_argument('--drop_noisy_vocs', type=_str2bool, nargs='?', const=True, help="drop noisy vocalisations", default=False)
args = parser.parse_args()
for specie in metadata.species: for specie in species if args.specie is None else args.specie.split(' '):
algos = {'pyin', 'praat', 'crepe', 'tcrepe', 'tcrepe_ft', 'basic', 'pesto', 'tcrepe_ftsp'} algos = {'pyin', 'praat', 'crepe', 'tcrepe', 'tcrepe_ft', 'basic', 'pesto', 'tcrepe_ftsp', 'pesto_ft'}
# Get optimal thresholds # Get optimal thresholds
confs = {k:[] for k in algos} confs = {k:[] for k in algos}
confs['label'] = [] confs['label'] = []
for fn in tqdm(glob(metadata.species[specie]['wavpath'][:-4]+'_preds.csv'), desc=f'{specie} get thrs', leave=False): for fn in tqdm(glob(species[specie]['wavpath'][:-4]+'_preds.csv'), desc=f'{specie} get thrs', leave=False):
if drop_noisy_vocs and os.path.isfile(f'noisy_pngs/{fn[:-10]}.png'): if args.drop_noisy_vocs and os.path.isfile(f'noisy_pngs/{fn[:-10]}.png'):
continue continue
df = pd.read_csv(fn) df = pd.read_csv(fn)
for algo in algos: for algo in algos:
...@@ -35,16 +39,18 @@ for specie in metadata.species: ...@@ -35,16 +39,18 @@ for specie in metadata.species:
# Compute recall, false alarm, pitch acc and chroma acc # Compute recall, false alarm, pitch acc and chroma acc
def fun(fn): def fun(fn):
if drop_noisy_vocs and os.path.isfile(f'noisy_pngs/{fn[:-10]}.png'): if args.drop_noisy_vocs and os.path.isfile(f'noisy_pngs/{fn[:-10]}.png'):
return pd.DataFrame() return pd.DataFrame()
df = pd.read_csv(fn).fillna(0) df = pd.read_csv(fn).fillna(0)
df.annot = mir_eval.melody.hz2cents(df.annot) df.annot = mir_eval.melody.hz2cents(df.annot)
if drop_noisy_bins and 'salience' in df.columns: if args.drop_noisy_bins and 'salience' in df.columns:
df.loc[((df.salience < 0.2) | (df.SHR > 10*np.log10(0.2))), 'annot'] = 0 df.loc[((df.salience < 0.2) | (df.SHR > 10*np.log10(0.2))), 'annot'] = 0
if not (df.annot > 0).any(): if not (df.annot > 0).any():
return pd.DataFrame() return pd.DataFrame()
out = pd.DataFrame(columns=metrics) out = pd.DataFrame(columns=metrics)
for algo in algos: for algo in algos:
if df[algo+'_f0'].isna().all():
continue
out.loc[algo, ['Recall', 'False alarm']] = mir_eval.melody.voicing_measures(df.annot > 0, df[algo+'_conf'] > thrs[algo]) out.loc[algo, ['Recall', 'False alarm']] = mir_eval.melody.voicing_measures(df.annot > 0, df[algo+'_conf'] > thrs[algo])
df[algo+'_f0'] = mir_eval.melody.hz2cents(df[algo+'_f0']) df[algo+'_f0'] = mir_eval.melody.hz2cents(df[algo+'_f0'])
df[algo+'_conf'].clip(0, 1, inplace=True) df[algo+'_conf'].clip(0, 1, inplace=True)
...@@ -53,7 +59,7 @@ for specie in metadata.species: ...@@ -53,7 +59,7 @@ for specie in metadata.species:
out.at[algo, 'diff_distrib'] = list(abs(df[algo+'_f0'] - df.annot)) out.at[algo, 'diff_distrib'] = list(abs(df[algo+'_f0'] - df.annot))
return out return out
df = pd.concat(p_umap(fun, glob(metadata.species[specie]['wavpath'][:-4]+'_preds.csv'), desc=f'{specie} get perf')) df = pd.concat(p_umap(fun, glob(species[specie]['wavpath'][:-4]+'_preds.csv'), desc=f'{specie} get perf'))
fig, ax = plt.subplots(ncols=3, figsize=(12, 4), sharex=True) fig, ax = plt.subplots(ncols=3, figsize=(12, 4), sharex=True)
for i, algo in enumerate(algos): for i, algo in enumerate(algos):
...@@ -67,12 +73,12 @@ for specie in metadata.species: ...@@ -67,12 +73,12 @@ for specie in metadata.species:
ax[2].set_title('Distrib of chroma acc per vocs in % ') ax[2].set_title('Distrib of chroma acc per vocs in % ')
plt.xticks(np.arange(len(algos)), algos, rotation=45) plt.xticks(np.arange(len(algos)), algos, rotation=45)
plt.tight_layout() plt.tight_layout()
plt.savefig(f'scores/{specie}_report{"_minusvocs" if drop_noisy_vocs else ""}{"_minusbins" if drop_noisy_bins else ""}.pdf') plt.savefig(f'scores/{specie}_report{"_minusvocs" if args.drop_noisy_vocs else ""}{"_minusbins" if args.drop_noisy_bins else ""}.pdf')
plt.close() plt.close()
df = df.reset_index(names='algo').groupby('algo').agg({'algo':'count', 'Recall':'mean', 'False alarm':'mean', 'Pitch acc':'mean', 'Chroma acc':'mean'}) df = df.reset_index(names='algo').groupby('algo').agg({'algo':'count', 'Recall':'mean', 'False alarm':'mean', 'Pitch acc':'mean', 'Chroma acc':'mean'})
df.loc[thrs.keys(), 'threshold'] = list(thrs.values()) df.loc[thrs.keys(), 'threshold'] = list(thrs.values())
df.rename(columns={'algo':'count'}, inplace=True) df.rename(columns={'algo':'count'}, inplace=True)
print(df) print(df)
df.to_csv(f'scores/{specie}_scores{"_minusvocs" if drop_noisy_vocs else ""}{"_minusbins" if drop_noisy_bins else ""}.csv') df.to_csv(f'scores/{specie}_scores{"_minusvocs" if args.drop_noisy_vocs else ""}{"_minusbins" if args.drop_noisy_bins else ""}.csv')
# df.to_latex(f'{specie}_scores.tex', float_format=lambda d: f'{d:.2f}') # df.to_latex(f'{specie}_scores.tex', float_format=lambda d: f'{d:.2f}')
\ No newline at end of file
from metadata import species from metadata import species
import pandas as pd, numpy as np, os, librosa, parselmouth, mir_eval import pandas as pd, numpy as np, os, argparse, librosa, parselmouth, mir_eval
from glob import glob from glob import glob
from tqdm import tqdm from tqdm import tqdm
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
...@@ -21,7 +21,7 @@ if overwrite: ...@@ -21,7 +21,7 @@ if overwrite:
print('Overwriting previous results') print('Overwriting previous results')
def run_tcrepe(model, sig, fs, dt): def run_tcrepe(model, sig, fs, dt):
generator = torchcrepe.core.preprocess(torch.tensor(sig).unsqueeze(0), fs, hop_length=dt*fs if fs != torchcrepe.SAMPLE_RATE else int(dt*fs), batch_size=batch_size, device='cuda:1', pad=False) generator = torchcrepe.core.preprocess(torch.tensor(sig).unsqueeze(0), fs, hop_length=dt*fs if fs != torchcrepe.SAMPLE_RATE else int(dt*fs), batch_size=batch_size, device='cuda', pad=False)
with torch.no_grad(): with torch.no_grad():
preds = np.vstack([model(frames).cpu().numpy() for frames in generator]) preds = np.vstack([model(frames).cpu().numpy() for frames in generator])
f0 = 10 * 2 ** (crepe.core.to_local_average_cents(preds) / 1200) f0 = 10 * 2 ** (crepe.core.to_local_average_cents(preds) / 1200)
...@@ -29,8 +29,11 @@ def run_tcrepe(model, sig, fs, dt): ...@@ -29,8 +29,11 @@ def run_tcrepe(model, sig, fs, dt):
time = np.arange(torchcrepe.WINDOW_SIZE/2, len(sig)/fs*torchcrepe.SAMPLE_RATE - torchcrepe.WINDOW_SIZE/2 + 1e-9, dt*torchcrepe.SAMPLE_RATE) / torchcrepe.SAMPLE_RATE time = np.arange(torchcrepe.WINDOW_SIZE/2, len(sig)/fs*torchcrepe.SAMPLE_RATE - torchcrepe.WINDOW_SIZE/2 + 1e-9, dt*torchcrepe.SAMPLE_RATE) / torchcrepe.SAMPLE_RATE
return time, f0, confidence return time, f0, confidence
# Command-line interface: an optional species selection given as one string
# (the loop that follows splits it on spaces).
parser = argparse.ArgumentParser()
# nargs='?' is what makes the positional optional. Without it argparse always
# requires a value, default=None is never used, and the
# `args.specie is None` fallback to every species can never trigger.
parser.add_argument('specie', type=str, nargs='?', help="Species to treat specifically", default=None)
args = parser.parse_args()
for specie in species: for specie in species if args.specie is None else args.specie.split(' '):
wavpath, FS, nfft, downsample = species[specie].values() wavpath, FS, nfft, downsample = species[specie].values()
dt = round(nfft / 8 / FS * downsample, 3) # winsize / 8 dt = round(nfft / 8 / FS * downsample, 3) # winsize / 8
...@@ -49,7 +52,7 @@ for specie in species: ...@@ -49,7 +52,7 @@ for specie in species:
out.loc[mask, 'annot'] = mir_eval.melody.resample_melody_series(annot.Time, annot.Freq, annot.Freq>0, out.loc[mask, 'time'], verbose=False)[0] out.loc[mask, 'annot'] = mir_eval.melody.resample_melody_series(annot.Time, annot.Freq, annot.Freq>0, out.loc[mask, 'time'], verbose=False)[0]
else: else:
out = pd.read_csv(f'{fn[:-4]}_preds.csv') out = pd.read_csv(f'{fn[:-4]}_preds.csv')
if pd.Series(['praat_f0','pyin_f0','crepe_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','basic_f0','pesto_f0']).isin(out.columns).all(): if pd.Series(['praat_f0','pyin_f0','crepe_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','basic_f0','pesto_f0', 'pesto_ftsp_f0']).isin(out.columns).all():
continue continue
sig, fs = librosa.load(fn, sr=FS) sig, fs = librosa.load(fn, sr=FS)
...@@ -103,11 +106,18 @@ for specie in species: ...@@ -103,11 +106,18 @@ for specie in species:
if not 'pesto_f0' in out.columns: # pesto if not 'pesto_f0' in out.columns: # pesto
try: try:
time, f0, confidence, activation = pesto.predict(torch.tensor(sig).unsqueeze(0), fs, step_size=int(dt*1e3), convert_to_freq=True) # step_size in ms time, f0, confidence, activation = pesto.predict(torch.tensor(sig).unsqueeze(0), fs, step_size=int(dt*1e3), convert_to_freq=True) # step_size in ms
out['pesto_f0'], out['pesto_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence.numpy(), out.time, verbose=False) out['pesto_f0'], out['pesto_conf'] = mir_eval.melody.resample_melody_series(time/1000, f0[0], confidence.numpy(), out.time, verbose=False)
out.pesto_f0 *= downsample out.pesto_f0 *= downsample
except Exception as inst: except Exception as inst:
out['pesto_f0'], out['pesto_conf'] = None, None out['pesto_f0'], out['pesto_conf'] = None, None
# print('pesto failed with '+fn, inst)
# Fine-tuned PESTO: only runs when a species-specific checkpoint exists and
# the prediction columns have not been written yet.
# The guard must test the column this block actually writes ('pesto_ft_f0').
# It previously tested 'pesto_ftsp_f0', which is never created, so the model
# was re-run for every file on every invocation.
# NOTE(review): the early-skip column list earlier in this loop also names
# 'pesto_ftsp_f0' and likely needs the same rename — confirm.
if 'pesto_ft_f0' not in out.columns and os.path.isfile(f'pesto_ft/{specie}.pth'):  # pesto finetuned
    try:
        time, f0, confidence, activation = pesto.predict(torch.tensor(sig).unsqueeze(0), fs, model_name=f'pesto_ft/{specie}.pth', step_size=int(dt*1e3), convert_to_freq=True)  # step_size in ms
        # Resample the predictions onto the shared time grid of `out`.
        # time/1000 presumably converts PESTO's milliseconds to seconds — TODO confirm.
        out['pesto_ft_f0'], out['pesto_ft_conf'] = mir_eval.melody.resample_melody_series(time/1000 + 1e-6, f0[0], confidence.numpy(), out.time, verbose=False)
        out.pesto_ft_f0 *= downsample
    except Exception as inst:
        # Deliberate best effort: a failed prediction leaves empty columns so
        # the file is still written and the other algorithms are unaffected.
        out['pesto_ft_f0'], out['pesto_ft_conf'] = None, None
out.annot *= downsample out.annot *= downsample
out.time /= downsample out.time /= downsample
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment