diff --git a/run_all.py b/run_all.py
index ae534c832cddf32bd9255ced258db8d6b0c49f88..4573c667969fa7009c1d54c0b0b3eeb5bd8adb17 100644
--- a/run_all.py
+++ b/run_all.py
@@ -3,22 +3,25 @@ import pandas as pd, numpy as np, os, argparse, librosa, parselmouth, mir_eval
 from glob import glob
 from tqdm import tqdm
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-import crepe, pesto, tensorflow as tf, torchcrepe, torch, basic_pitch.inference, basic_pitch.constants
+import crepe, pesto, torchcrepe, torch, tensorflow as tf, basic_pitch.inference, basic_pitch.constants
 
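+# Use the GPU with batch size 64 when available, otherwise fall back to CPU with batch size 1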
+device, batch_size = ('cuda', 64) if torch.cuda.is_available() else ('cpu', 1)
+
+# LOAD MODELS
 basic_pitch_model = tf.saved_model.load(str(basic_pitch.ICASSP_2022_MODEL_PATH))
 
-tcrepe_model = torchcrepe.Crepe('full').eval().to('cuda')
+tcrepe_model = torchcrepe.Crepe('full').eval().to(device)
-tcrepe_model.load_state_dict(torch.load('/home/paul.best/.local/lib/python3.9/site-packages/torchcrepe/assets/full.pth', map_location='cuda'))
+tcrepe_model.load_state_dict(torch.load('/home/paul.best/.local/lib/python3.9/site-packages/torchcrepe/assets/full.pth', map_location=device))
 
-tcrepe_ft_model = torchcrepe.Crepe('full').eval().to('cuda')
-tcrepe_ft_model.load_state_dict(torch.load('crepe_ft/model_all.pth', map_location='cuda'))
-
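+# Centres of CREPE's 360 pitch bins in cents above 10 Hz: 20-cent spacing, bin 0 ≈ 31.7 Hz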
 cents_mapping = np.linspace(0, 7180, 360) + 1997.3794084376191
-batch_size = 64
 
 def run_tcrepe(model, sig, fs, dt):
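+    # torchcrepe resamples to 16 kHz internally and rescales hop_length, so the hop is left
+    # as a float when fs differs from torchcrepe.SAMPLE_RATE to avoid rounding it twice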
     generator = torchcrepe.core.preprocess(torch.tensor(sig).unsqueeze(0), fs, hop_length=dt*fs if fs != torchcrepe.SAMPLE_RATE else int(dt*fs),\
-                                           batch_size=batch_size, device='cuda', pad=False)
+                                           batch_size=batch_size, device=device, pad=False)
     with torch.no_grad():
         preds = np.vstack([model(frames).cpu().numpy() for frames in generator])
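+    # decode activations to Hz: local weighted average in cents, relative to a 10 Hz reference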
     f0 = 10 * 2 ** (crepe.core.to_local_average_cents(preds) / 1200)
@@ -26,42 +25,62 @@ def run_tcrepe(model, sig, fs, dt):
     time = np.arange(torchcrepe.WINDOW_SIZE/2, len(sig)/fs*torchcrepe.SAMPLE_RATE - torchcrepe.WINDOW_SIZE/2 + 1e-9, dt*torchcrepe.SAMPLE_RATE) / torchcrepe.SAMPLE_RATE
     return time, f0, confidence
 
+# PARSE ARGUMENTS
 parser = argparse.ArgumentParser()
-parser.add_argument('specie', type=str, help="Species to treat specifically", default=None)
+parser.add_argument('specie', type=str, help="Species to process ('all' or a space-separated list)", default=None)
-parser.add_argument('--overwrite', type=bool, help="Overwrite previous pedictions", default=False)
-parser.add_argument('--quick', type=bool, help="Skip pyin and crepe to make things quick", default=False)
+parser.add_argument('--overwrite', action='store_true', help="Overwrite previous predictions")
+parser.add_argument('--quick', action='store_true', help="Skip crepe to make things quick")
+parser.add_argument('--split', type=int, help="Which fifth of the files (0-4) to evaluate on, matching the split held out in training", default=None)
 args = parser.parse_args()
 
-algos = ['praat_f0','pyin_f0','crepe_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
-quick_algos = ['praat_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
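+# suffixes: ft = finetuned on all species, ftsp = finetuned on the target species only, ftoth = finetuned with the target species left out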
+algos = ['praat_f0','pyin_f0','crepe_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0', 'pesto_ftoth_f0']
+quick_algos = ['praat_f0','pyin_f0','tcrepe_f0','tcrepe_ft_f0','tcrepe_ftsp_f0','tcrepe_ftoth_f0', 'basic_f0','pesto_f0', 'pesto_ft_f0']
 
 if args.overwrite:
     print('Overwriting previous results')
 
+# Iterate over species, then files, then run each algorithm and save the predictions
 for specie in species if args.specie =='all' else args.specie.split(' '):
     wavpath, FS, nfft, downsample, step = species[specie].values()
     dt = round(nfft * step / FS * downsample, 3) # winsize / 8
-
+    # Load the species-specific fine-tuned models (trained only on this species / with this species omitted)
     tcrepe_ftoth_model, tcrepe_ftsp_model = None, None
-    if os.path.isfile(f'crepe_ft/model_only_{specie}.pth'):
-        tcrepe_ftsp_model = torchcrepe.Crepe('full').eval().to('cuda')
-        tcrepe_ftsp_model.load_state_dict(torch.load(f'crepe_ft/model_only_{specie}.pth', map_location='cuda'))
+    if os.path.isfile(f'crepe_ft/model_only-{args.split}_{specie}.pth'):
+        tcrepe_ftsp_model = torchcrepe.Crepe('full').eval().to(device)
+        tcrepe_ftsp_model.load_state_dict(torch.load(f'crepe_ft/model_only-{args.split}_{specie}.pth', map_location=device))
     if os.path.isfile(f'crepe_ft/model_omit_{specie}.pth'):
-        tcrepe_ftoth_model = torchcrepe.Crepe('full').eval().to('cuda')
-        tcrepe_ftoth_model.load_state_dict(torch.load(f'crepe_ft/model_omit_{specie}.pth', map_location='cuda'))
-
-    for fn in tqdm(glob(wavpath), desc=specie):
+        tcrepe_ftoth_model = torchcrepe.Crepe('full').eval().to(device)
+        tcrepe_ftoth_model.load_state_dict(torch.load(f'crepe_ft/model_omit_{specie}.pth', map_location=device))
+    # initialise the file list to iterate on
+    fns = sorted(glob(wavpath)) # sort so the 5-fold split is deterministic
+    if args.split is not None:
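+        # evaluate only the args.split-th fifth of the files (the portion held out from fine-tuning)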
+        fns = fns[int(len(fns)/5*args.split) : int(len(fns)/5*(args.split+1))]
+    # iterate over files
+    for fn in tqdm(fns, desc=specie):
         if args.overwrite or not os.path.isfile(f'{fn[:-4]}_preds.csv') or os.path.getsize(f'{fn[:-4]}_preds.csv') < 300:
+            # load original annotation file
             annot = pd.read_csv(f'{fn[:-4]}.csv').drop_duplicates(subset='Time')
+            # insert zeros at the edges of large gaps so resampling doesn't interpolate across the silence between vocalisations
+            med_diff = annot.Time.diff().median()
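+            # lgaps: last annotated time before each gap; rgaps: first annotated time after it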
+            rgaps, lgaps = annot.Time[annot.Time.diff() > med_diff*4], annot.Time[annot.Time.diff(-1) < -med_diff*4]
+            annot = pd.concat([annot, pd.DataFrame({'Time':np.concatenate([lgaps+med_diff, rgaps-med_diff]), 'Freq':[0]*(len(lgaps)+len(rgaps))})]).sort_values('Time')
+            # load the waveform and create the dataframe for storing predictions
             sig, fs = librosa.load(fn, sr=FS)
             out = pd.DataFrame({'time':np.arange(nfft/fs/2, (len(sig) - nfft/2)/fs, dt / downsample)})
             mask = ((out.time > annot.Time.min())&(out.time < annot.Time.max()))
-            out.loc[mask, 'annot'] = mir_eval.melody.resample_melody_series(annot.Time, annot.Freq, annot.Freq>0, out.loc[mask, 'time'], verbose=False)[0]
+            out.loc[mask, 'annot'] = mir_eval.melody.resample_melody_series(annot.Time, annot.Freq, annot.Freq>0, out[mask].time, verbose=False)[0]
         else:
-            out = pd.read_csv(f'{fn[:-4]}_preds.csv').dropna(axis=1, how='all') # drop a column if all values are None
+            out = pd.read_csv(f'{fn[:-4]}_preds.csv')
+            for algo in algos: # drop a column if all values are None
+                if algo in out.columns and out[algo].isna().all():
+                    out.drop(algo, axis=1, inplace=True)
+
             # check if everything has already been computed, and if yes skip the file
             if pd.Series(algos).isin(out.columns).all() or (args.quick and pd.Series(quick_algos).isin(out.columns).all()):
                 continue
             sig, fs = librosa.load(fn, sr=FS)
 
         out.time *= downsample
@@ -74,23 +90,17 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out.loc[mask, 'praat_f0'], out.loc[mask, 'praat_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.loc[mask, 'time'])
             out.praat_f0 *= downsample
 
-        if not args.quick and not 'pyin_f0' in out.columns: # PYIN
+        if not 'pyin_f0' in out.columns: # PYIN
             f0, voiced, prob = librosa.pyin(sig, sr=fs, fmin=27.5, fmax=fs//2, frame_length=nfft, hop_length=int(fs*dt), center=False)
             out['pyin_f0'], out['pyin_conf'] = f0[:len(out)], prob[:len(out)]
             out.pyin_f0 *= downsample
 
-        if not 'tcrepe_f0' in out.columns: # torch crepe pretrained
+        if not 'tcrepe_f0' in out.columns: # torch crepe out-of-the-box
             time, f0, confidence = run_tcrepe(tcrepe_model, sig, fs, dt)
             mask = ((out.time > time[0])&(out.time < time[-1]))
             out.loc[mask, 'tcrepe_f0'], out.loc[mask, 'tcrepe_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
             out.tcrepe_f0 *= downsample
 
-        if not 'tcrepe_ft_f0' in out.columns: # torch crepe finetuned on all species
-            time, f0, confidence = run_tcrepe(tcrepe_ft_model, sig, fs, dt)
-            mask = ((out.time > time[0])&(out.time < time[-1]))
-            out.loc[mask, 'tcrepe_ft_f0'], out.loc[mask, 'tcrepe_ft_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
-            out.tcrepe_ft_f0 *= downsample
-
         if not 'tcrepe_ftsp_f0' in out.columns and tcrepe_ftsp_model: # torch crepe finetuned on the target species
             time, f0, confidence = run_tcrepe(tcrepe_ftsp_model, sig, fs, dt)
             mask = ((out.time > time[0])&(out.time < time[-1]))
@@ -103,7 +113,7 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out.loc[mask, 'tcrepe_ftoth_f0'], out.loc[mask, 'tcrepe_ftoth_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out[mask].time)
             out.tcrepe_ftoth_f0 *= downsample
 
-        if not args.quick and not 'crepe_f0' in out.columns: # CREPE
+        if not args.quick and not 'crepe_f0' in out.columns: # CREPE out-of-the-box tensorflow
             time, f0, confidence, activation = crepe.predict(sig, fs, step_size=int(dt*1e3), center=False, verbose=0) # step_size in ms
             out['crepe_f0'], out['crepe_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.time)
             out.crepe_f0 *= downsample
@@ -116,12 +126,13 @@ for specie in species if args.specie =='all' else args.specie.split(' '):
             out['basic_f0'], out['basic_conf'] = mir_eval.melody.resample_melody_series(time, f0, confidence, out.time)
             out.basic_f0 *= downsample
 
-        if not 'pesto_f0' in out.columns: # pesto
+        if not 'pesto_f0' in out.columns: # pesto out-of-the-box
             try:
                 time, f0, confidence, activation = pesto.predict(torch.tensor(sig).unsqueeze(0), fs, step_size=int(dt*1e3), convert_to_freq=True) # step_size in ms
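+                # pesto timesteps are in ms; convert to seconds for mir_eval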
                 out['pesto_f0'], out['pesto_conf'] = mir_eval.melody.resample_melody_series(time/1000, f0[0], confidence.numpy(), out.time, verbose=False)
                 out.pesto_f0 *= downsample
-            except Exception as inst:
+            except Exception:
                 out['pesto_f0'], out['pesto_conf'] = None, None
 
         if not 'pesto_ft_f0' in out.columns and os.path.isfile(f'pesto_ft/{specie}.pth'): # pesto finetuned
diff --git a/train_crepe.py b/train_crepe.py
index 852608a52b0eb7d5564a746181b59c43820a042e..8af7ee8e8009ba5a66aea96ade33aaa43aba93ae 100644
--- a/train_crepe.py
+++ b/train_crepe.py
@@ -8,9 +8,11 @@ from torch.utils.tensorboard import SummaryWriter
 parser = argparse.ArgumentParser()
 parser.add_argument('--omit', type=str, help="Species to rule out of the training set", default=None)
 parser.add_argument('--only', type=str, help="Train only on the given species", default=None)
+parser.add_argument('--split', type=int, help="Which fifth of the files (0-4) to hold out as the test set", default=0)
 args = parser.parse_args()
 
-suffix = "omit_"+args.omit if args.omit else "only_"+args.only if args.only else "all"
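+# suffix identifies the training configuration: omit_<species>, only-<split>_<species>, or all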
+suffix = "omit_"+args.omit if args.omit else f"only-{args.split}_"+args.only if args.only else "all"
 
 writer = SummaryWriter('runs/'+suffix)
 model = torchcrepe.Crepe('full')
@@ -26,6 +27,9 @@ if not os.path.isfile(f'crepe_ft/train_set_{suffix}.pkl'):
         wavpath, fs, nfft, downsample, step = species[specie].values()
         dt = int(n_in * step) # winsize / 8
-        files = glob.glob(wavpath)
+        files = sorted(glob.glob(wavpath)) # sort so the held-out fifth matches run_all.py
+        if args.only:
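+            # drop the args.split-th fifth from training; run_all.py evaluates on exactly that portion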
+            files = files[:int(len(files)/5*args.split)] + files[int(len(files)/5*(args.split+1)):]
-        for fn in tqdm.tqdm(pd.Series(files).sample(min(len(files), 1000)), desc='Peparing dataset for '+specie):
+        for fn in tqdm.tqdm(pd.Series(files).sample(min(len(files), 1000)), desc='Preparing dataset for '+specie):
             if os.path.isfile(f'noisy_pngs/{fn[:-4]}.png'):
                 continue