From 810fb16b95dee080ddf7d108231d0f387f812eee Mon Sep 17 00:00:00 2001
From: Paul Best <paul.best@lis-lab.fr>
Date: Mon, 2 Sep 2024 16:28:12 +0200
Subject: [PATCH] predict.py: filter inputs to soundfile-readable audio, add --threshold and --NFFT plotting options, update defaults

---
 predict.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/predict.py b/predict.py
index 6d3207c..0ef6400 100644
--- a/predict.py
+++ b/predict.py
@@ -1,6 +1,6 @@
 import matplotlib.pyplot as plt
 import argparse, os, tqdm
-import torchcrepe, torch, librosa
+import torchcrepe, torch, librosa, soundfile
 import pandas as pd, numpy as np
 
 parser = argparse.ArgumentParser()
@@ -9,8 +9,10 @@ parser.add_argument('--model_path', type=str, help="Path of model weights", defa
 parser.add_argument('--compress', type=float, help="Compression factor used to shift frequencies into CREPE's range [32Hz; 2kHz]. \
     Frequencies are divided by the given factor by artificially changing the sampling rate (slowing down / speeding up the signal).", default=1)
 parser.add_argument('--step', type=float, help="Step used between each prediction (in seconds)", default=256 / torchcrepe.SAMPLE_RATE)
-parser.add_argument('--decoder', choices=['argmax', 'weighted_argmax', 'viterbi'], help="Decoder used to postprocess predictions", default='viterbi')
-parser.add_argument('--print', type=bool, help="Print spectrograms with overlaid F0 predictions to assess their quality", default=False)
+parser.add_argument('--decoder', choices=['argmax', 'weighted_argmax', 'viterbi'], help="Decoder used to postprocess predictions", default='weighted_argmax')
+parser.add_argument('--print', type=lambda x: str(x).lower() not in ('false', '0', 'no'), help="Print spectrograms with overlaid F0 predictions to assess their quality (pass False to disable)", default=True)
+parser.add_argument('--threshold', type=float, help="Confidence threshold used when printing F0 predictions on spectrograms", default=0.1)
+parser.add_argument('--NFFT', type=int, help="Window size used for the spectrogram computation (only used for printing F0 predictions)", default=1024)
 args = parser.parse_args()
 
 # Initialisations
@@ -19,7 +21,9 @@ model = torchcrepe.Crepe('full').eval().to(device)
 model.load_state_dict(torch.load(args.model_path, map_location=device))
 decoder = torchcrepe.decode.__dict__[args.decoder]
 
-for filename in tqdm.tqdm(os.listdir(args.indir)):
+files = [fn for fn in os.listdir(args.indir) if fn.split('.')[-1].upper() in soundfile.available_formats()]  # keep only soundfile-readable audio formats
+
+for filename in tqdm.tqdm(files):
     try:
         sig, fs = librosa.load(os.path.join(args.indir, filename), sr=int(torchcrepe.SAMPLE_RATE * args.compress))
     except:
@@ -39,8 +43,11 @@ for filename in tqdm.tqdm(os.listdir(args.indir)):
     # Plot F0 predictions over spectrograms
     if args.print:
         plt.figure(figsize=(max(6.4, 6.4*time[-1]/2), 4.8))
-        plt.specgram(sig, Fs=fs)
-        plt.scatter(time, f0, c=confidence)
+        plt.specgram(sig, Fs=fs, NFFT=args.NFFT, noverlap=args.NFFT - args.NFFT//8)  # hop of NFFT/8 (87.5% overlap)
+        mask = confidence > args.threshold  # keep only confident F0 estimates
+        plt.scatter(time[mask], f0[mask], c=confidence[mask], s=5)
+        plt.xlim(0, len(sig)/fs)
+        plt.ylim(0, f0[mask].max() * 1.5 if mask.any() else fs/2)  # fall back to Nyquist if no prediction passes the threshold
         plt.colorbar()
         plt.tight_layout()
         plt.savefig(os.path.join(args.indir, filename.rsplit('.',1)[0])+'_f0.png')
-- 
GitLab
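
Note (not part of the patch): a minimal, self-contained sketch of the confidence-threshold plotting logic this patch introduces. The arrays below are toy stand-ins for the script's CREPE outputs; their names, shapes, and values are illustrative assumptions only, not taken from the repository.

    import numpy as np
    import matplotlib.pyplot as plt

    # Toy stand-ins for the script's per-frame outputs (assumed values, for illustration)
    fs = 16000                                  # sampling rate in Hz
    time = np.linspace(0, 2, 200)               # prediction times in seconds
    f0 = 440 + 50 * np.sin(2 * np.pi * time)    # fake F0 track in Hz
    confidence = np.random.rand(200)            # fake per-frame confidence in [0, 1]
    threshold = 0.1                             # same default as the new --threshold option

    # Keep only predictions whose confidence exceeds the threshold, as in the patch
    mask = confidence > threshold
    plt.scatter(time[mask], f0[mask], c=confidence[mask], s=5)
    plt.xlim(0, time[-1])
    plt.ylim(0, f0[mask].max() * 1.5 if mask.any() else fs / 2)
    plt.colorbar(label="confidence")
    plt.tight_layout()
    plt.show()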