diff --git a/run_all.py b/run_all.py
index ae534c832cddf32bd9255ced258db8d6b0c49f88..4573c667969fa7009c1d54c0b0b3eeb5bd8adb17 100644
--- a/run_all.py
+++ b/run_all.py
@@ -3,22 +3,21 @@
 import pandas as pd, numpy as np, os, argparse, librosa, parselmouth, mir_eval
 from glob import glob
 from tqdm import tqdm
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-import crepe, pesto, tensorflow as tf, torchcrepe, torch, basic_pitch.inference, basic_pitch.constants
+import crepe, pesto, torchcrepe, torch, tensorflow as tf, basic_pitch.inference, basic_pitch.constants
+device, batch_size = ('cuda', 64) if torch.cuda.is_available() else ('cpu', 1)
+
+# LOAD MODELS
 basic_pitch_model = tf.saved_model.load(str(basic_pitch.ICASSP_2022_MODEL_PATH))
-tcrepe_model = torchcrepe.Crepe('full').eval().to('cuda')
+tcrepe_model = torchcrepe.Crepe('full').eval().to(device)
 tcrepe_model.load_state_dict(torch.load('/home/paul.best/.local/lib/python3.9/site-packages/torchcrepe/assets/full.pth', map_location='cuda'))
-tcrepe_ft_model = torchcrepe.Crepe('full').eval().to('cuda')
-tcrepe_ft_model.load_state_dict(torch.load('crepe_ft/model_all.pth', map_location='cuda'))
-
 cents_mapping = np.linspace(0, 7180, 360) + 1997.3794084376191
-batch_size = 64

 def run_tcrepe(model, sig, fs, dt):
     generator = torchcrepe.core.preprocess(torch.tensor(sig).unsqueeze(0), fs, hop_length=dt*fs if fs != torchcrepe.SAMPLE_RATE else int(dt*fs),\
-        batch_size=batch_size, device='cuda', pad=False)
+        batch_size=batch_size, device=device, pad=False)
     with torch.no_grad():
         preds = np.vstack([model(frames).cpu().numpy() for frames in generator])
     f0 = 10 * 2 ** (crepe.core.to_local_average_cents(preds) / 1200)
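For context, `run_tcrepe` decodes the network's 360 activation bins into Hz via `cents_mapping` and the `10 * 2 ** (cents / 1200)` conversion above. A minimal sketch of that decoding, assuming a hypothetical `activation` batch (plain argmax here, whereas `crepe.core.to_local_average_cents` takes a weighted average of the bins around the peak for sub-bin precision):

```python
import numpy as np

# CREPE's fixed mapping from its 360 output bins to absolute pitch in cents
cents_mapping = np.linspace(0, 7180, 360) + 1997.3794084376191

def decode_f0(activation):
    # simplified decoding: argmax bin -> cents -> Hz;
    # crepe.core.to_local_average_cents instead averages bins around the peak
    cents = cents_mapping[activation.argmax(axis=1)]
    return 10 * 2 ** (cents / 1200)  # cents are expressed relative to 10 Hz

activation = np.random.rand(4, 360)  # hypothetical (frames, bins) activation batch
print(decode_f0(activation))         # one f0 estimate in Hz per frame
```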
@@ -26,42 +25,59 @@ def run_tcrepe(model, sig, fs, dt):
     time = np.arange(torchcrepe.WINDOW_SIZE/2, len(sig)/fs*torchcrepe.SAMPLE_RATE - torchcrepe.WINDOW_SIZE/2 + 1e-9, dt*torchcrepe.SAMPLE_RATE) / torchcrepe.SAMPLE_RATE
     return time, f0, confidence

+# PARSE ARGUMENTS
 parser = argparse.ArgumentParser()
 parser.add_argument('specie', type=str, help="Species to process ('all' for every species)", default=None)
 parser.add_argument('--overwrite', type=bool, help="Overwrite previous predictions", default=False)
 parser.add_argument('--quick', type=bool, help="Skip crepe to make things quick", default=False)
+parser.add_argument('--split', type=int, help="Index (0 to 4) of the fifth of files to test on", default=None)
 args = parser.parse_args()

-algos = ['praat_f0','pyin_f0','crepe_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
-quick_algos = ['praat_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
+algos = ['praat_f0','pyin_f0','crepe_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0', 'pesto_ftoth_f0']
+quick_algos = ['praat_f0','pyin_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']

 if args.overwrite:
     print('Overwriting previous results')

+# Iterate over species, then files, then run each algorithm and save the predictions
 for specie in species if args.specie =='all' else args.specie.split(' '):
     wavpath, FS, nfft, downsample, step = species[specie].values()
     dt = round(nfft * step / FS * downsample, 3) # winsize / 8
-
+    # Load species-specific pre-trained models
     tcrepe_ftoth_model, tcrepe_ftsp_model = None, None
-    if os.path.isfile(f'crepe_ft/model_only_{specie}.pth'):
-        tcrepe_ftsp_model = torchcrepe.Crepe('full').eval().to('cuda')
-        tcrepe_ftsp_model.load_state_dict(torch.load(f'crepe_ft/model_only_{specie}.pth', map_location='cuda'))
+    if os.path.isfile(f'crepe_ft/model_only-{args.split}_{specie}.pth'):
+        tcrepe_ftsp_model = torchcrepe.Crepe('full').eval().to(device)
+        tcrepe_ftsp_model.load_state_dict(torch.load(f'crepe_ft/model_only-{args.split}_{specie}.pth', map_location=device))
     if os.path.isfile(f'crepe_ft/model_omit_{specie}.pth'):
-        tcrepe_ftoth_model = torchcrepe.Crepe('full').eval().to('cuda')
-        tcrepe_ftoth_model.load_state_dict(torch.load(f'crepe_ft/model_omit_{specie}.pth', map_location='cuda'))
-
-    for fn in tqdm(glob(wavpath), desc=specie):
+        tcrepe_ftoth_model = torchcrepe.Crepe('full').eval().to(device)
+        tcrepe_ftoth_model.load_state_dict(torch.load(f'crepe_ft/model_omit_{specie}.pth', map_location=device))
+    # initialise the file list to iterate on
+    fns = glob(wavpath)
+    if args.split is not None:
+        fns = fns[int(len(fns)/5*args.split) : int(len(fns)/5*(args.split+1))]
+    # iterate over files
+    for fn in tqdm(fns, desc=specie):
         if args.overwrite or not os.path.isfile(f'{fn[:-4]}_preds.csv') or os.path.getsize(f'{fn[:-4]}_preds.csv') < 300:
+            # load the original annotation file
             annot = pd.read_csv(f'{fn[:-4]}.csv').drop_duplicates(subset='Time')
+            # insert a 0 Hz anchor at the edges of large gaps to avoid interpolating between vocalisations
+            med_diff = annot.Time.diff().median()
+            rgaps, lgaps = annot.Time[annot.Time.diff() > med_diff*4], annot.Time[annot.Time.diff(-1) < - med_diff * 4]
+            annot = pd.concat([annot, pd.DataFrame({'Time':np.concatenate([lgaps+med_diff, rgaps-med_diff]), 'Freq':[0]*(len(lgaps)+len(rgaps))})]).sort_values('Time')
+            # load the waveform and create the dataframe for storing predictions
             sig, fs = librosa.load(fn, sr=FS)
             out = pd.DataFrame({'time':np.arange(nfft/fs/2, (len(sig) - nfft/2)/fs, dt / downsample)})
             mask = ((out.time > annot.Time.min())&(out.time < annot.Time.max()))
-            out.loc[mask, 'annot'] = mir_eval.melody.resample_melody_series(annot.Time, annot.Freq, annot.Freq>0, out.loc[mask, 'time'], verbose=False)[0]
+            out.loc[mask, 'annot'] = mir_eval.melody.resample_melody_series(annot.Time, annot.Freq, annot.Freq>0, out[mask].time, verbose=False)[0]
         else:
-            out = pd.read_csv(f'{fn[:-4]}_preds.csv').dropna(axis=1, how='all') # drop a column if all values are None
+            out = pd.read_csv(f'{fn[:-4]}_preds.csv')
+            for algo in algos: # drop a column if all its values are None
+                if algo in out.columns and out[algo].isna().all():
+                    out.drop(algo, axis=1, inplace=True)
+            # skip the file if everything has already been computed
             if pd.Series(algos).isin(out.columns).all() or (args.quick and pd.Series(quick_algos).isin(out.columns).all()):
-                 continue
+                continue
             sig, fs = librosa.load(fn, sr=FS)
             out.time *= downsample
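To see what the new gap-padding lines do, here is a self-contained sketch with hypothetical annotation values: without the inserted 0 Hz anchors, resampling would interpolate a spurious pitch glide across the silence between two vocalisations.

```python
import numpy as np
import pandas as pd

# hypothetical annotation: two vocalisations separated by a long silence
annot = pd.DataFrame({'Time': [0.00, 0.01, 0.02, 1.00, 1.01, 1.02],
                      'Freq': [440, 450, 460, 500, 510, 520]})

med_diff = annot.Time.diff().median()                    # typical annotation hop (0.01 s)
rgaps = annot.Time[annot.Time.diff() > med_diff * 4]     # first frame after a gap (1.00)
lgaps = annot.Time[annot.Time.diff(-1) < -med_diff * 4]  # last frame before a gap (0.02)
# unvoiced (Freq=0) anchors one hop inside each gap, here at 0.03 s and 0.99 s
pad = pd.DataFrame({'Time': np.concatenate([lgaps + med_diff, rgaps - med_diff]),
                    'Freq': [0] * (len(lgaps) + len(rgaps))})
annot = pd.concat([annot, pad]).sort_values('Time')
```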
@@ -74,23 +90,17 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out.loc[mask, 'praat_f0'], out.loc[mask, 'praat_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.loc[mask, 'time'])
             out.praat_f0 *= downsample

-        if not args.quick and not 'pyin_f0' in out.columns: # PYIN
+        if not 'pyin_f0' in out.columns: # PYIN
             f0, voiced, prob = librosa.pyin(sig, sr=fs, fmin=27.5, fmax=fs//2, frame_length=nfft, hop_length=int(fs*dt), center=False)
             out['pyin_f0'], out['pyin_conf'] = f0[:len(out)], prob[:len(out)]
             out.pyin_f0 *= downsample

-        if not 'tcrepe_f0' in out.columns: # torch crepe pretrained
+        if not 'tcrepe_f0' in out.columns: # torch crepe out-of-the-box
             time, f0, confidence = run_tcrepe(tcrepe_model, sig, fs, dt)
             mask = ((out.time > time[0])&(out.time < time[-1]))
             out.loc[mask, 'tcrepe_f0'], out.loc[mask, 'tcrepe_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
             out.tcrepe_f0 *= downsample

-        if not 'tcrepe_ft_f0' in out.columns: # torch crepe finetuned on all species
-            time, f0, confidence = run_tcrepe(tcrepe_ft_model, sig, fs, dt)
-            mask = ((out.time > time[0])&(out.time < time[-1]))
-            out.loc[mask, 'tcrepe_ft_f0'], out.loc[mask, 'tcrepe_ft_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
-            out.tcrepe_ft_f0 *= downsample
-
         if not 'tcrepe_ftsp_f0' in out.columns and tcrepe_ftsp_model: # torch crepe finetuned on the target species
             time, f0, confidence = run_tcrepe(tcrepe_ftsp_model, sig, fs, dt)
             mask = ((out.time > time[0])&(out.time < time[-1]))
@@ -103,7 +113,7 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out.loc[mask, 'tcrepe_ftoth_f0'], out.loc[mask, 'tcrepe_ftoth_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
             out.tcrepe_ftoth_f0 *= downsample

-        if not args.quick and not 'crepe_f0' in out.columns: # CREPE
+        if not args.quick and not 'crepe_f0' in out.columns: # CREPE out-of-the-box (TensorFlow)
             time, f0, confidence, activation = crepe.predict(sig, fs, step_size=int(dt*1e3), center=False, verbose=0) # step_size in ms
             out['crepe_f0'], out['crepe_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.time)
             out.crepe_f0 *= downsample
@@ -116,12 +126,12 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out['basic_f0'], out['basic_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.time)
             out.basic_f0 *= downsample

-        if not 'pesto_f0' in out.columns: # pesto
+        if not 'pesto_f0' in out.columns: # pesto out-of-the-box
             try:
                 time, f0, confidence, activation = pesto.predict(torch.tensor(sig).unsqueeze(0), fs, step_size=int(dt*1e3), convert_to_freq=True) # step_size in ms
                 out['pesto_f0'], out['pesto_conf'] = mir_eval.melody.resample_melody_series(time/1000, f0[0], confidence.numpy(), out.time, verbose=False)
                 out.pesto_f0 *= downsample
-            except Exception as inst:
+            except Exception:
                 out['pesto_f0'], out['pesto_conf'] = None, None

         if not 'pesto_ft_f0' in out.columns and os.path.isfile(f'pesto_ft/{specie}.pth'): # pesto finetuned
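Each algorithm block above follows the same pattern: run the tracker, mask the shared time grid down to the span the tracker actually covered (so the resampler never extrapolates), and write the resampled f0 and confidence back into `out`. A sketch of that pattern with hypothetical tracker output:

```python
import numpy as np
import pandas as pd
import mir_eval

# hypothetical tracker output: 1 s of A440 at a 10 ms hop
time = np.arange(0, 1.0, 0.01)
f0 = np.full_like(time, 440.0)
confidence = np.linspace(0.5, 1.0, len(time))

# shared evaluation grid (out.time in the script), longer than the tracker's span
out = pd.DataFrame({'time': np.arange(0, 1.2, 0.016)})

mask = (out.time > time[0]) & (out.time < time[-1])  # stay inside the tracker's span
out.loc[mask, 'algo_f0'], out.loc[mask, 'algo_conf'] = \
    mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time.to_numpy())
```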
diff --git a/train_crepe.py b/train_crepe.py
index 852608a52b0eb7d5564a746181b59c43820a042e..8af7ee8e8009ba5a66aea96ade33aaa43aba93ae 100644
--- a/train_crepe.py
+++ b/train_crepe.py
@@ -8,9 +8,10 @@ from torch.utils.tensorboard import SummaryWriter
 parser = argparse.ArgumentParser()
 parser.add_argument('--omit', type=str, help="Species to rule out of the training set", default=None)
 parser.add_argument('--only', type=str, help="Train only on the given species", default=None)
+parser.add_argument('--split', type=int, help="Fifth of the files (0 to 4) to hold out as the test set", default=0)
 args = parser.parse_args()

-suffix = "omit_"+args.omit if args.omit else "only_"+args.only if args.only else "all"
+suffix = "omit_"+args.omit if args.omit else f"only-{args.split}_"+args.only if args.only else "all"
 writer = SummaryWriter('runs/'+suffix)

 model = torchcrepe.Crepe('full')
@@ -26,6 +27,8 @@ if not os.path.isfile(f'crepe_ft/train_set_{suffix}.pkl'):
         wavpath, fs, nfft, downsample, step = species[specie].values()
         dt = int(n_in * step) # winsize / 8
         files = glob.glob(wavpath)
+        if args.only:
+            files = files[:int(len(files)/5*args.split)] + files[int(len(files)/5*(args.split+1)):]
         for fn in tqdm.tqdm(pd.Series(files).sample(min(len(files), 1000)), desc='Preparing dataset for '+specie):
             if os.path.isfile(f'noisy_pngs/{fn[:-4]}.png'): continue
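The two `--split` changes are complementary: `train_crepe.py --only <specie> --split k` trains on everything except the k-th fifth of the files and saves `crepe_ft/model_only-k_<specie>.pth`, which `run_all.py <specie> --split k` then evaluates on exactly that fifth. A sketch with hypothetical file names (note that both scripts rely on `glob` returning the files in the same order; sorting the list would make the split deterministic across runs):

```python
files = [f'vocalisation_{i:03d}.wav' for i in range(10)]  # hypothetical file list
split = 2

test = files[int(len(files)/5*split) : int(len(files)/5*(split+1))]            # run_all.py
train = files[:int(len(files)/5*split)] + files[int(len(files)/5*(split+1)):]  # train_crepe.py

assert sorted(train + test) == sorted(files)  # every file lands in exactly one set
print(test)  # ['vocalisation_004.wav', 'vocalisation_005.wav']
```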