Skip to content
Snippets Groups Projects
Select Git revision
  • a6f7fa15474f2cf759f79e2d1d4df67777ed74a2
  • main default protected
2 results

print_detections.py

Blame
  • print_detections.py 2.00 KiB
    import os, argparse
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    import pandas as pd, numpy as np
    import models, utils as u
    import torch
    
    
    # Command-line interface: one required positional (the detections CSV) plus
    # optional audio / spectrogram parameters. Output folder corrected in the
    # description: the script writes into detections_pngs/, not annot_pngs/.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="Plot vocalisation spectrograms into detections_pngs/")
    parser.add_argument("detections", type=str, help=".csv file with detections to be encoded. Columns filename (path of the soundfile) and pos (center of the detection in seconds) are needed")
    parser.add_argument("-audio_folder", type=str, default='./', help="Folder from which to load sound files")
    parser.add_argument("-NFFT", type=int, default=1024, help="FFT size for the spectrogram computation")
    parser.add_argument("-nMel", type=int, default=128, help="Number of Mel bands for the spectrogram (either 64 or 128)")
    parser.add_argument("-SR", type=int, default=44100, help="Sample rate of the samples before spectrogram computation")
    parser.add_argument("-sampleDur", type=float, default=1, help="Size of the signal extracts surrounding detections to be encoded")
    parser.add_argument('-medfilt', action='store_true', help="If a frequency-wise median filter is desired (a larger sampleDur will be used only for a better median estimation)")
    # NOTE(review): no '-feature' option is defined anywhere, so this default is
    # likely a copy-paste leftover; kept so args.feature stays available — confirm
    # no external caller reads it before removing.
    parser.set_defaults(feature=False)
    args = parser.parse_args()
    
    
    # Build the spectrogram frontend; the medfilt variant applies a
    # frequency-wise median filter (see models module — project-defined).
    frontend = models.frontend_medfilt(args.SR, args.NFFT, args.sampleDur, args.nMel) if args.medfilt else models.frontend(args.SR, args.NFFT, args.sampleDur, args.nMel)
    # Expected columns: filename (soundfile path) and pos (detection centre, s).
    df = pd.read_csv(args.detections)
    
    # Ensure the output folder exists and is empty. This replaces a shell
    # `os.system('rm detections_pngs/*')`, which was non-portable and left the
    # script to crash at savefig time when the folder was missing.
    os.makedirs('detections_pngs', exist_ok=True)
    for leftover in os.listdir('detections_pngs'):
        os.remove(os.path.join('detections_pngs', leftover))
    # With -medfilt, load 2 extra seconds around each detection only so the
    # per-frequency median is estimated over a longer context.
    loader = torch.utils.data.DataLoader(u.Dataset(df, args.audio_folder, args.SR, args.sampleDur + (2 if args.medfilt else 0)), batch_size=1, num_workers=8, collate_fn=u.collate_fn, shuffle=True)
    
    # Render one greyscale spectrogram PNG per detection, named after its
    # dataframe index. Contrast is stretched between the 25th and 98th
    # quantiles of each individual spectrogram.
    for batch, det_idx in tqdm(loader):
        spec = frontend(batch).squeeze().detach()
        lo = torch.quantile(spec, .25)
        hi = torch.quantile(spec, .98)
        plt.imshow(spec, origin='lower', aspect='auto', vmin=lo, cmap='Greys', vmax=hi)
        # Remove all margins so the image fills the whole figure.
        plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
        plt.savefig(f'detections_pngs/{det_idx.item()}')
        plt.close()