Commit 245880da authored by Paul Best

k-fold training and inference

parent 3ca2b65f
@@ -3,22 +3,21 @@ import pandas as pd, numpy as np, os, argparse, librosa, parselmouth, mir_eval
 from glob import glob
 from tqdm import tqdm
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-import crepe, pesto, tensorflow as tf, torchcrepe, torch, basic_pitch.inference, basic_pitch.constants
+import crepe, pesto, torchcrepe, torch, tensorflow as tf, basic_pitch.inference, basic_pitch.constants
+device, batch_size = ('cuda', 64) if torch.cuda.is_available() else ('cpu', 1)
 # LOAD MODELS
 basic_pitch_model = tf.saved_model.load(str(basic_pitch.ICASSP_2022_MODEL_PATH))
-tcrepe_model = torchcrepe.Crepe('full').eval().to('cuda')
+tcrepe_model = torchcrepe.Crepe('full').eval().to(device)
 tcrepe_model.load_state_dict(torch.load('/home/paul.best/.local/lib/python3.9/site-packages/torchcrepe/assets/full.pth', map_location='cuda'))
 tcrepe_ft_model = torchcrepe.Crepe('full').eval().to('cuda')
 tcrepe_ft_model.load_state_dict(torch.load('crepe_ft/model_all.pth', map_location='cuda'))
 cents_mapping = np.linspace(0, 7180, 360) + 1997.3794084376191
-batch_size = 64
 def run_tcrepe(model, sig, fs, dt):
     generator = torchcrepe.core.preprocess(torch.tensor(sig).unsqueeze(0), fs, hop_length=dt*fs if fs != torchcrepe.SAMPLE_RATE else int(dt*fs),\
-        batch_size=batch_size, device='cuda', pad=False)
+        batch_size=batch_size, device=device, pad=False)
     with torch.no_grad():
         preds = np.vstack([model(frames).cpu().numpy() for frames in generator])
     f0 = 10 * 2 ** (crepe.core.to_local_average_cents(preds) / 1200)
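
For reference, cents_mapping and the f0 = 10 * 2 ** (cents / 1200) line above are CREPE's standard decoding of its 360-bin activation into Hz: the bins sit 20 cents apart, spanning roughly 32 Hz to 2 kHz, with cents expressed relative to 10 Hz. Below is a minimal sketch of that decoding, assuming a (n_frames, 360) activation array; acts and to_hz are made-up names, and the script itself uses crepe.core.to_local_average_cents, which averages bin centres over a small window around the argmax rather than globally.

import numpy as np

# same vector as cents_mapping above: 360 bin centres, 20 cents apart
cents_mapping = np.linspace(0, 7180, 360) + 1997.3794084376191

def to_hz(acts):
    # acts: (n_frames, 360) activations; an activation-weighted average of
    # bin centres gives a cents estimate per frame (a crude global version
    # of what crepe.core.to_local_average_cents does around the argmax)
    cents = (acts * cents_mapping).sum(axis=1) / acts.sum(axis=1)
    return 10 * 2 ** (cents / 1200)  # cents relative to 10 Hz -> Hz
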
@@ -26,39 +25,56 @@ def run_tcrepe(model, sig, fs, dt):
     time = np.arange(torchcrepe.WINDOW_SIZE/2, len(sig)/fs*torchcrepe.SAMPLE_RATE - torchcrepe.WINDOW_SIZE/2 + 1e-9, dt*torchcrepe.SAMPLE_RATE) / torchcrepe.SAMPLE_RATE
     return time, f0, confidence
 # PARSE ARGUMENTS
 parser = argparse.ArgumentParser()
 parser.add_argument('specie', type=str, help="Species to treat specifically", default=None)
 parser.add_argument('--overwrite', type=bool, help="Overwrite previous pedictions", default=False)
 parser.add_argument('--quick', type=bool, help="Skip pyin and crepe to make things quick", default=False)
+parser.add_argument('--split', type=int, help="Section to test on between 0 and 4", default=None)
 args = parser.parse_args()
-algos = ['praat_f0','pyin_f0','crepe_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
-quick_algos = ['praat_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
+algos = ['praat_f0','pyin_f0','crepe_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0', 'pesto_ftoth_f0']
+quick_algos = ['praat_f0','pyin_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
 if args.overwrite:
     print('Overwriting previous results')
 # Iterate over species, then files, then run each algorithm and save the predictions
 for specie in species if args.specie =='all' else args.specie.split(' '):
     wavpath, FS, nfft, downsample, step = species[specie].values()
     dt = round(nfft * step / FS * downsample, 3) # winsize / 8
     # Load species specific pre-trained models
     tcrepe_ftoth_model, tcrepe_ftsp_model = None, None
-    if os.path.isfile(f'crepe_ft/model_only_{specie}.pth'):
-        tcrepe_ftsp_model = torchcrepe.Crepe('full').eval().to('cuda')
-        tcrepe_ftsp_model.load_state_dict(torch.load(f'crepe_ft/model_only_{specie}.pth', map_location='cuda'))
+    if os.path.isfile(f'crepe_ft/model_only-{args.split}_{specie}.pth'):
+        tcrepe_ftsp_model = torchcrepe.Crepe('full').eval().to(device)
+        tcrepe_ftsp_model.load_state_dict(torch.load(f'crepe_ft/model_only-{args.split}_{specie}.pth', map_location=device))
     if os.path.isfile(f'crepe_ft/model_omit_{specie}.pth'):
-        tcrepe_ftoth_model = torchcrepe.Crepe('full').eval().to('cuda')
-        tcrepe_ftoth_model.load_state_dict(torch.load(f'crepe_ft/model_omit_{specie}.pth', map_location='cuda'))
-    for fn in tqdm(glob(wavpath), desc=specie):
+        tcrepe_ftoth_model = torchcrepe.Crepe('full').eval().to(device)
+        tcrepe_ftoth_model.load_state_dict(torch.load(f'crepe_ft/model_omit_{specie}.pth', map_location=device))
+    # initialise the file list to iterate on
+    fns = glob(wavpath)
+    if type(args.split) == int:
+        fns = fns[int(len(fns)/5*args.split) : int(len(fns)/5*(args.split+1))]
+    # iterate over files
+    for fn in tqdm(fns, desc=specie):
         if args.overwrite or not os.path.isfile(f'{fn[:-4]}_preds.csv') or os.path.getsize(f'{fn[:-4]}_preds.csv') < 300:
             # load original annotation file
             annot = pd.read_csv(f'{fn[:-4]}.csv').drop_duplicates(subset='Time')
             # add a 0 at starts and ends for large gaps to avoid interpolating between vocalisations
             med_diff = annot.Time.diff().median()
             rgaps, lgaps = annot.Time[annot.Time.diff() > med_diff*4], annot.Time[annot.Time.diff(-1) < - med_diff * 4]
             annot = pd.concat([annot, pd.DataFrame({'Time':np.concatenate([lgaps+med_diff, rgaps-med_diff]), 'Freq':[0]*(len(lgaps)+len(rgaps))})]).sort_values('Time')
             # load the waveform and create the dataframe for storing predictions
             sig, fs = librosa.load(fn, sr=FS)
             out = pd.DataFrame({'time':np.arange(nfft/fs/2, (len(sig) - nfft/2)/fs, dt / downsample)})
             mask = ((out.time > annot.Time.min())&(out.time < annot.Time.max()))
-            out.loc[mask, 'annot'] = mir_eval.melody.resample_melody_series(annot.Time, annot.Freq, annot.Freq>0, out.loc[mask, 'time'], verbose=False)[0]
+            out.loc[mask, 'annot'] = mir_eval.melody.resample_melody_series(annot.Time, annot.Freq, annot.Freq>0, out[mask].time, verbose=False)[0]
         else:
-            out = pd.read_csv(f'{fn[:-4]}_preds.csv').dropna(axis=1, how='all') # drop a column if all values are None
+            out = pd.read_csv(f'{fn[:-4]}_preds.csv')
+            for algo in algos: # drop a column if all values are None
+                if algo in out.columns and out[algo].isna().all():
+                    out.drop(algo, axis=1, inplace=True)
         # check if everything has already been computed, and if yes skip the file
         if pd.Series(algos).isin(out.columns).all() or (args.quick and pd.Series(quick_algos).isin(out.columns).all()):
             continue
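
The rgaps/lgaps lines above deserve a note: they keep the later resampling step from interpolating f0 across silences between vocalisations. Wherever consecutive annotation times are more than four median hops apart, an unvoiced (0 Hz) anchor is inserted just inside each edge of the gap. A minimal sketch of the same logic on toy data (times and frequencies are made up):

import numpy as np, pandas as pd

# two vocalisations separated by a long silent gap
annot = pd.DataFrame({'Time': [0.0, 0.1, 0.2, 2.0, 2.1, 2.2],
                      'Freq': [500., 510., 505., 700., 710., 705.]})
med_diff = annot.Time.diff().median()                    # typical hop: 0.1 s
rgaps = annot.Time[annot.Time.diff() > med_diff * 4]     # right edge of each gap
lgaps = annot.Time[annot.Time.diff(-1) < -med_diff * 4]  # left edge of each gap
pad = pd.DataFrame({'Time': np.concatenate([lgaps + med_diff, rgaps - med_diff]),
                    'Freq': 0.})                         # unvoiced anchors
annot = pd.concat([annot, pad]).sort_values('Time')
# -> anchors at t=0.3 and t=1.9, so interpolation drops to 0 Hz inside the gap
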
@@ -74,23 +90,17 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out.loc[mask, 'praat_f0'], out.loc[mask, 'praat_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.loc[mask, 'time'])
             out.praat_f0 *= downsample
-        if not args.quick and not 'pyin_f0' in out.columns: # PYIN
+        if not 'pyin_f0' in out.columns: # PYIN
             f0, voiced, prob = librosa.pyin(sig, sr=fs, fmin=27.5, fmax=fs//2, frame_length=nfft, hop_length=int(fs*dt), center=False)
             out['pyin_f0'], out['pyin_conf'] = f0[:len(out)], prob[:len(out)]
             out.pyin_f0 *= downsample
-        if not 'tcrepe_f0' in out.columns: # torch crepe pretrained
+        if not 'tcrepe_f0' in out.columns: # torch crepe out-of-the-box
             time, f0, confidence = run_tcrepe(tcrepe_model, sig, fs, dt)
             mask = ((out.time > time[0])&(out.time < time[-1]))
             out.loc[mask, 'tcrepe_f0'], out.loc[mask, 'tcrepe_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
             out.tcrepe_f0 *= downsample
         if not 'tcrepe_ft_f0' in out.columns: # torch crepe finetuned on all species
             time, f0, confidence = run_tcrepe(tcrepe_ft_model, sig, fs, dt)
             mask = ((out.time > time[0])&(out.time < time[-1]))
             out.loc[mask, 'tcrepe_ft_f0'], out.loc[mask, 'tcrepe_ft_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
             out.tcrepe_ft_f0 *= downsample
         if not 'tcrepe_ftsp_f0' in out.columns and tcrepe_ftsp_model: # torch crepe finetuned on the target species
             time, f0, confidence = run_tcrepe(tcrepe_ftsp_model, sig, fs, dt)
             mask = ((out.time > time[0])&(out.time < time[-1]))
@@ -103,7 +113,7 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out.loc[mask, 'tcrepe_ftoth_f0'], out.loc[mask, 'tcrepe_ftoth_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
             out.tcrepe_ftoth_f0 *= downsample
-        if not args.quick and not 'crepe_f0' in out.columns: # CREPE
+        if not args.quick and not 'crepe_f0' in out.columns: # CREPE out-of-the-box tensorflow
             time, f0, confidence, activation = crepe.predict(sig, fs, step_size=int(dt*1e3), center=False, verbose=0) # step_size in ms
             out['crepe_f0'], out['crepe_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.time)
             out.crepe_f0 *= downsample
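
Each detector above runs on its own time base, so every block resamples its output onto the shared out.time grid with mir_eval.melody.resample_melody_series, masking to the span the detector actually covered so nothing is extrapolated past its first or last frame. A self-contained sketch of that recurring pattern (the detector output arrays are made-up stand-ins):

import numpy as np, pandas as pd, mir_eval

out = pd.DataFrame({'time': np.arange(0.0, 1.0, 0.01)})  # shared prediction grid
time = np.arange(0.05, 0.95, 0.02)                       # one detector's own frame times
f0 = np.full_like(time, 440.0)                           # its f0 estimates (Hz)
confidence = np.full_like(time, 0.9)                     # its voicing confidence

# only fill grid points inside the detector's span, then interpolate onto them
mask = (out.time > time[0]) & (out.time < time[-1])
out.loc[mask, 'algo_f0'], out.loc[mask, 'algo_conf'] = \
    mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time.values)
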
@@ -116,12 +126,12 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out['basic_f0'], out['basic_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.time)
             out.basic_f0 *= downsample
-        if not 'pesto_f0' in out.columns: # pesto
+        if not 'pesto_f0' in out.columns: # pesto out-of-the-box
             try:
                 time, f0, confidence, activation = pesto.predict(torch.tensor(sig).unsqueeze(0), fs, step_size=int(dt*1e3), convert_to_freq=True) # step_size in ms
                 out['pesto_f0'], out['pesto_conf'] = mir_eval.melody.resample_melody_series(time/1000, f0[0], confidence.numpy(), out.time, verbose=False)
                 out.pesto_f0 *= downsample
-            except Exception as inst:
+            except:
                 out['pesto_f0'], out['pesto_conf'] = None, None
         if not 'pesto_ft_f0' in out.columns and os.path.isfile(f'pesto_ft/{specie}.pth'): # pesto finetuned
......
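
A note on how the two scripts stay in sync: the fold index is encoded in the checkpoint filename. The training script below builds suffix = f"only-{split}_" + specie, and the checkpoint is then presumably saved as crepe_ft/model_{suffix}.pth (the save call is outside this diff), which is exactly the path the inference script above probes. Hypothetical values for illustration:

specie, split = 'humpback', 3          # made-up species and fold
suffix = f"only-{split}_" + specie     # as built in the training script below
ckpt = f'crepe_ft/model_{suffix}.pth'  # -> 'crepe_ft/model_only-3_humpback.pth'
# matches the inference lookup: f'crepe_ft/model_only-{args.split}_{specie}.pth'
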
@@ -8,9 +8,10 @@ from torch.utils.tensorboard import SummaryWriter
 parser = argparse.ArgumentParser()
 parser.add_argument('--omit', type=str, help="Species to rule out of the training set", default=None)
 parser.add_argument('--only', type=str, help="Train only on the given species", default=None)
+parser.add_argument('--split', type=int, help="Portion out of between 0 and 4 to use as test set", default=0)
 args = parser.parse_args()
-suffix = "omit_"+args.omit if args.omit else "only_"+args.only if args.only else "all"
+suffix = "omit_"+args.omit if args.omit else f"only-{args.split}_"+args.only if args.only else "all"
 writer = SummaryWriter('runs/'+suffix)
 model = torchcrepe.Crepe('full')
@@ -26,6 +27,8 @@ if not os.path.isfile(f'crepe_ft/train_set_{suffix}.pkl'):
         wavpath, fs, nfft, downsample, step = species[specie].values()
         dt = int(n_in * step) # winsize / 8
         files = glob.glob(wavpath)
+        if args.only:
+            files = files[:int(len(files)/5*args.split)] + files[int(len(files)/5*(args.split+1)):]
         for fn in tqdm.tqdm(pd.Series(files).sample(min(len(files), 1000)), desc='Peparing dataset for '+specie):
             if os.path.isfile(f'noisy_pngs/{fn[:-4]}.png'):
                 continue
......
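
Taken together, the two --split arguments implement the k-fold scheme of the commit title: for fold k, training drops the k-th fifth of each species' file list and inference evaluates exactly that fifth. Note this relies on glob returning the files in the same order in both scripts. A minimal check that the two slices partition the files, using the same integer arithmetic as above (fns is a made-up file list):

fns = [f'rec_{i:02d}.wav' for i in range(23)]  # hypothetical file list
for k in range(5):
    lo, hi = int(len(fns)/5*k), int(len(fns)/5*(k+1))
    test = fns[lo:hi]            # inference script slice for --split k
    train = fns[:lo] + fns[hi:]  # training script keeps the complement
    assert sorted(train + test) == sorted(fns)  # each file lands in exactly one side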