import os
import argparse

import numpy as np
import pandas as pd
import soundfile as sf
import torch
from scipy import signal, special
from tqdm import tqdm

import models

parser = argparse.ArgumentParser(description="Run this script to use a CNN for the detection of cetacean vocalizations on a folder of audio files.")
parser.add_argument('audio_folder', type=str, help='Path of the folder with audio files to process')
parser.add_argument('specie', type=str, help='Target species to detect', choices=['megaptera', 'delphinid', 'orcinus', 'physeter', 'balaenoptera', 'globicephala'])
parser.add_argument('-lensample', type=float, help='Length of the signal for each sample (in seconds)', default=5)
parser.add_argument('-batch_size', type=int, help='Number of samples to process at a time (useful for parallel computation on a GPU)', default=32)
parser.add_argument('-channel', type=int, help='Channel of the audio file to use for model inference (starting from 0)', default=0)
parser.add_argument('-maxPool', help='Whether to keep only the maximal prediction of each sample or the full sequence', action='store_true')
parser.add_argument('-no-maxPool', dest='maxPool', action='store_false')
parser.add_argument('-output_filename', type=str, help='Name of the output file for saving predictions', default='')
parser.set_defaults(maxPool=True)
args = parser.parse_args()


def collate_fn(batch):
    # Drop samples that failed to load so one bad file does not break the batch
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch) if len(batch) > 0 else None


# Normalise a signal to zero mean and unit variance
norm = lambda arr: (arr - np.mean(arr)) / np.std(arr)


# Pytorch dataset class to load audio samples
class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()
        self.samples = []
        # Index every readable audio file as a list of fixed-length samples
        for fn in tqdm(os.listdir(args.audio_folder), desc='Dataset initialization', leave=False):
            try:
                info = sf.info(os.path.join(args.audio_folder, fn))
                duration, fs = info.duration, info.samplerate
                self.samples.extend([{'fn': fn, 'offset': offset, 'fs': fs}
                                     for offset in np.arange(0, duration + 1e-10 - args.lensample, args.lensample)])
            except Exception:
                continue
            assert info.channels > args.channel, f"The desired channel is unavailable for the audio file {fn}"

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        try:
            sig, fs = sf.read(os.path.join(args.audio_folder, sample['fn']),
                              start=int(sample['offset'] * sample['fs']),
                              stop=int((sample['offset'] + args.lensample) * sample['fs']),
                              always_2d=True)
        except Exception:
            print('Failed loading ' + sample['fn'])
            return None
        sig = sig[:, args.channel]
        # Resample to the rate expected by the target species' model
        if fs != models.get[args.specie]['fs']:
            sig = signal.resample(sig, int(args.lensample * models.get[args.specie]['fs']))
        sig = norm(sig)
        return torch.tensor(sig).float(), sample

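# Note: the `models` module is not part of this script. From its usage below,
# it is assumed to expose a dict named `get` mapping each species name to its
# model architecture, expected sample rate, and weights file. A minimal sketch
# of that assumed interface (the values shown are illustrative placeholders,
# not the actual ones):
#
#   get = {
#       'megaptera': {
#           'archi': SomeCNN(),          # torch.nn.Module outputting one logit per time step
#           'fs': 11025,                 # sample rate (Hz) the model was trained at
#           'weights': 'megaptera.pt',   # file name under <script_dir>/weights/
#       },
#       # ... one entry per supported species
#   }
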
# prepare model
model = models.get[args.specie]['archi']
# Load the weights stored next to this script (os.path.join handles both
# relative and absolute script locations)
model.load_state_dict(torch.load(os.path.join(os.path.dirname(__file__), 'weights', models.get[args.specie]['weights'])))
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# prepare data loader and output storage for predictions
loader = torch.utils.data.DataLoader(Dataset(), batch_size=args.batch_size, collate_fn=collate_fn,
                                     num_workers=8, prefetch_factor=4)
if len(loader) == 0:
    print(f'Unable to open any audio file in the given folder {args.audio_folder}')
    exit()

fns, offsets, preds = [], [], []

# forward the model on each batch
with torch.no_grad():
    for batch in tqdm(loader, desc='Model inference'):
        if batch is None:  # every sample in this batch failed to load
            continue
        x, meta = batch
        x = x.to(device)
        # Sigmoid maps the raw logits to [0, 1] detection probabilities
        pred = special.expit(model(x).cpu().detach().numpy())
        if args.maxPool:
            # Keep only the peak probability of each sample
            pred = pred.max(axis=-1).reshape(len(x))
        else:
            # Keep the full prediction sequence of each sample
            pred = pred.reshape(len(x), -1)
        preds.extend(pred)
        fns.extend(meta['fn'])
        offsets.extend(meta['offset'].numpy())

out = pd.DataFrame({'filename': fns, 'offset': offsets, 'prediction': preds})

# Default output name: the audio folder's name, with .csv for scalar
# predictions and .pkl for sequences
if args.output_filename == '':
    pred_fn = list(filter(lambda e: e != '', args.audio_folder.split('/')))[-1] + ('.csv' if args.maxPool else '.pkl')
else:
    pred_fn = args.output_filename
print(f'Saving results into {pred_fn}')
if pred_fn.endswith('csv'):
    out.to_csv(pred_fn, index=False)
else:
    out.to_pickle(pred_fn)
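
# Example invocation (assuming this script is saved as predict.py; the actual
# file name is not given in the source):
#
#   python predict.py /data/recordings megaptera -batch_size 16 -no-maxPool
#
# With -maxPool (the default), one peak probability per sample is written to a
# .csv file; with -no-maxPool, the full per-sample prediction sequences are
# written to a .pkl file, since arrays do not fit a flat csv column.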