Skip to content
Snippets Groups Projects
Commit 810fb16b authored by Paul Best's avatar Paul Best
Browse files

Update file predict.py

parent c73022bc
No related branches found
No related tags found
No related merge requests found
import matplotlib.pyplot as plt
import argparse, os, tqdm
import torchcrepe, torch, librosa
import torchcrepe, torch, librosa, soundfile
import pandas as pd, numpy as np
parser = argparse.ArgumentParser()
......@@ -9,8 +9,10 @@ parser.add_argument('--model_path', type=str, help="Path of model weights", defa
parser.add_argument('--compress', type=float, help="Compression factor used to shift frequencies into CREPE's range [32Hz; 2kHz]. \
Frequencies are divided by the given factor by artificially changing the sampling rate (slowing down / speeding up the signal).", default=1)
parser.add_argument('--step', type=float, help="Step used between each prediction (in seconds)", default=256 / torchcrepe.SAMPLE_RATE)
parser.add_argument('--decoder', choices=['argmax', 'weighted_argmax', 'viterbi'], help="Decoder used to postprocess predictions", default='viterbi')
parser.add_argument('--print', type=bool, help="Print spectrograms with overlaid F0 predictions to assess their quality", default=False)
parser.add_argument('--decoder', choices=['argmax', 'weighted_argmax', 'viterbi'], help="Decoder used to postprocess predictions", default='weighted_argmax')
parser.add_argument('--print', type=bool, help="Print spectrograms with overlaid F0 predictions to assess their quality", default=True)
parser.add_argument('--threshold', type=float, help="Confidence threshold used when printing F0 predictions on spectrograms ", default=0.1)
parser.add_argument('--NFFT', type=int, help="Window size used for the spectrogram computation (only used for printing F0 predictions)", default=1024)
args = parser.parse_args()
# Initialisations
......@@ -19,7 +21,9 @@ model = torchcrepe.Crepe('full').eval().to(device)
model.load_state_dict(torch.load(args.model_path, map_location=device))
decoder = torchcrepe.decode.__dict__[args.decoder]
for filename in tqdm.tqdm(os.listdir(args.indir)):
files = [fn for fn in os.listdir(args.indir) if fn.split('.')[-1].upper() in soundfile._formats]
for filename in tqdm.tqdm(files):
try:
sig, fs = librosa.load(os.path.join(args.indir, filename), sr=int(torchcrepe.SAMPLE_RATE * args.compress))
except:
......@@ -39,8 +43,11 @@ for filename in tqdm.tqdm(os.listdir(args.indir)):
# Plot F0 predictions over spectrograms
if args.print:
plt.figure(figsize=(max(6.4, 6.4*time[-1]/2), 4.8))
plt.specgram(sig, Fs=fs)
plt.scatter(time, f0, c=confidence)
plt.specgram(sig, Fs=fs, NFFT=args.NFFT, noverlap=args.NFFT-args.NFFT//8)
mask = confidence>args.threshold
plt.scatter(time[mask], f0[mask], c=confidence[mask], s=5)
plt.xlim(0, len(sig)/fs)
plt.ylim(0, f0[mask].max() * 1.5)
plt.colorbar()
plt.tight_layout()
plt.savefig(os.path.join(args.indir, filename.rsplit('.',1)[0])+'_f0.png')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment