Skip to content
Snippets Groups Projects
Commit 062ea49d authored by Paul Best's avatar Paul Best
Browse files

small fixes

parent 50858fba
No related branches found
No related tags found
No related merge requests found
...@@ -11,13 +11,13 @@ from tqdm import tqdm ...@@ -11,13 +11,13 @@ from tqdm import tqdm
cent_thr = 50 cent_thr = 50
metrics = ['recall', 'FA', 'pitch_acc', 'chroma_acc', 'diff_distrib'] metrics = ['recall', 'FA', 'pitch_acc', 'chroma_acc', 'diff_distrib']
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('specie', type=str, help="Species to treat specifically", default=None) parser.add_argument('specie', type=str, help="Species to treat specifically", default='all')
parser.add_argument('--drop_noisy_bins', type=bool, help="drop noisy vocalisations", default=False) parser.add_argument('--drop_noisy_bins', type=bool, help="drop noisy vocalisations", default=False)
parser.add_argument('--drop_noisy_vocs', type=bool, help="drop noisy STFT bins", default=False) parser.add_argument('--drop_noisy_vocs', type=bool, help="drop noisy STFT bins", default=False)
args = parser.parse_args() args = parser.parse_args()
for specie in species if args.specie is None else args.specie.split(' '): for specie in species if args.specie=='all' else args.specie.split(' '):
algos = {'pyin', 'praat', 'crepe', 'tcrepe', 'tcrepe_ft', 'basic', 'pesto', 'tcrepe_ftsp', 'pesto_ft'} algos = {'pyin', 'praat', 'crepe', 'tcrepe', 'tcrepe_ft', 'tcrepe_ftsp', 'tcrepe_ftoth', 'basic', 'pesto', 'pesto_ft'}
# Get optimal thresholds # Get optimal thresholds
confs = {k:[] for k in algos} confs = {k:[] for k in algos}
confs['label'] = [] confs['label'] = []
......
...@@ -12,8 +12,8 @@ parser.add_argument('specie', type=str, help="Species to treat specifically", de ...@@ -12,8 +12,8 @@ parser.add_argument('specie', type=str, help="Species to treat specifically", de
args = parser.parse_args() args = parser.parse_args()
for specie in species if args.specie is None else [args.specie]: for specie in species if args.specie is None else [args.specie]:
wavpath, FS, nfft, downsample = species[specie].values() wavpath, FS, nfft, downsample, step = species[specie].values()
dt = nfft / 8 / FS # winsize / 8 dt = nfft * step / FS # winsize / 8
Hz2bin = lambda f: np.round(f/FS*nfft).astype(int) Hz2bin = lambda f: np.round(f/FS*nfft).astype(int)
# for fn in glob(wavpath): # for fn in glob(wavpath):
def fun(fn): def fun(fn):
......
...@@ -7,11 +7,11 @@ import crepe, pesto, tensorflow as tf, torchcrepe, torch, basic_pitch.inference, ...@@ -7,11 +7,11 @@ import crepe, pesto, tensorflow as tf, torchcrepe, torch, basic_pitch.inference,
basic_pitch_model = tf.saved_model.load(str(basic_pitch.ICASSP_2022_MODEL_PATH)) basic_pitch_model = tf.saved_model.load(str(basic_pitch.ICASSP_2022_MODEL_PATH))
tcrepe_model = torchcrepe.Crepe('full').eval().to('cuda:1') tcrepe_model = torchcrepe.Crepe('full').eval().to('cuda')
tcrepe_model.load_state_dict(torch.load('/home/paul.best/.local/lib/python3.9/site-packages/torchcrepe/assets/full.pth', map_location='cuda:1')) tcrepe_model.load_state_dict(torch.load('/home/paul.best/.local/lib/python3.9/site-packages/torchcrepe/assets/full.pth', map_location='cuda'))
tcrepe_ft_model = torchcrepe.Crepe('full').eval().to('cuda:1') tcrepe_ft_model = torchcrepe.Crepe('full').eval().to('cuda')
tcrepe_ft_model.load_state_dict(torch.load('crepe_ft/model_all.pth', map_location='cuda:1')) tcrepe_ft_model.load_state_dict(torch.load('crepe_ft/model_all.pth', map_location='cuda'))
cents_mapping = np.linspace(0, 7180, 360) + 1997.3794084376191 cents_mapping = np.linspace(0, 7180, 360) + 1997.3794084376191
...@@ -33,20 +33,20 @@ parser = argparse.ArgumentParser() ...@@ -33,20 +33,20 @@ parser = argparse.ArgumentParser()
parser.add_argument('specie', type=str, help="Species to treat specifically", default=None) parser.add_argument('specie', type=str, help="Species to treat specifically", default=None)
args = parser.parse_args() args = parser.parse_args()
for specie in species if args.specie is None else args.specie.split(' '): for specie in species if args.specie =='all' else args.specie.split(' '):
wavpath, FS, nfft, downsample = species[specie].values() wavpath, FS, nfft, downsample, step = species[specie].values()
dt = round(nfft / 8 / FS * downsample, 3) # winsize / 8 dt = round(nfft * step / FS * downsample, 3) # winsize / 8
tcrepe_ftoth_model, tcrepe_ftsp_model = None, None tcrepe_ftoth_model, tcrepe_ftsp_model = None, None
if os.path.isfile(f'crepe_ft/model_only_{specie}.pth'): if os.path.isfile(f'crepe_ft/model_only_{specie}.pth'):
tcrepe_ftsp_model = torchcrepe.Crepe('full').eval().to('cuda:1') tcrepe_ftsp_model = torchcrepe.Crepe('full').eval().to('cuda')
tcrepe_ftsp_model.load_state_dict(torch.load(f'crepe_ft/model_only_{specie}.pth', map_location='cuda:1')) tcrepe_ftsp_model.load_state_dict(torch.load(f'crepe_ft/model_only_{specie}.pth', map_location='cuda'))
if os.path.isfile(f'crepe_ft/model_omit_{specie}.pth'): if os.path.isfile(f'crepe_ft/model_omit_{specie}.pth'):
tcrepe_ftoth_model = torchcrepe.Crepe('full').eval().to('cuda:1') tcrepe_ftoth_model = torchcrepe.Crepe('full').eval().to('cuda')
tcrepe_ftoth_model.load_state_dict(torch.load(f'crepe_ft/model_omit_{specie}.pth', map_location='cuda:1')) tcrepe_ftoth_model.load_state_dict(torch.load(f'crepe_ft/model_omit_{specie}.pth', map_location='cuda'))
for fn in tqdm(glob(wavpath), desc=specie): for fn in tqdm(glob(wavpath), desc=specie):
if overwrite or not os.path.isfile(f'{fn[:-4]}_preds.csv'): if overwrite or not os.path.isfile(f'{fn[:-4]}_preds.csv') or os.path.getsize(f'{fn[:-4]}_preds.csv') < 300:
annot = pd.read_csv(f'{fn[:-4]}.csv').drop_duplicates(subset='Time') annot = pd.read_csv(f'{fn[:-4]}.csv').drop_duplicates(subset='Time')
sig, fs = librosa.load(fn, sr=FS) sig, fs = librosa.load(fn, sr=FS)
out = pd.DataFrame({'time':np.arange(nfft/fs/2, (len(sig) - nfft/2)/fs, dt / downsample)}) out = pd.DataFrame({'time':np.arange(nfft/fs/2, (len(sig) - nfft/2)/fs, dt / downsample)})
...@@ -59,7 +59,6 @@ for specie in species if args.specie is None else args.specie.split(' '): ...@@ -59,7 +59,6 @@ for specie in species if args.specie is None else args.specie.split(' '):
continue continue
sig, fs = librosa.load(fn, sr=FS) sig, fs = librosa.load(fn, sr=FS)
out.annot /= downsample
out.time *= downsample out.time *= downsample
fs /= downsample fs /= downsample
...@@ -128,6 +127,5 @@ for specie in species if args.specie is None else args.specie.split(' '): ...@@ -128,6 +127,5 @@ for specie in species if args.specie is None else args.specie.split(' '):
except Exception as inst: except Exception as inst:
out['pesto_ft_f0'], out['pesto_ft_conf'] = None, None out['pesto_ft_f0'], out['pesto_ft_conf'] = None, None
out.annot *= downsample
out.time /= downsample out.time /= downsample
out.to_csv(f'{fn[:-4]}_preds.csv', index=False) out.to_csv(f'{fn[:-4]}_preds.csv', index=False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment