Select Git revision
print_detections.py
print_detections.py 2.00 KiB
import os, argparse
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd, numpy as np
import models, utils as u
import torch
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Plot vocalisation spectrograms into annot_pngs/")
parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
parser.set_defaults(feature=False)
args = parser.parse_args()
frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
df = pd.read_csv(args.detections)
os.system('rm detections_pngs/*')
loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=1, num_workers=8, collate_fn=u.collate_fn, shuffle=True)
for x, idx in tqdm(loader):
x = frontend(x).squeeze().detach()
plt.imshow(x, origin='lower', aspect='auto', vmin=torch.quantile(x, .25), cmap='Greys', vmax=torch.quantile(x, .98))
plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
plt.savefig(f'detections_pngs/{idx.item()}')
plt.close()