run_model.py
    from models import frontend_logstft, resnet
    from scipy import signal, special
    import soundfile as sf
    import torch
    import numpy as np
    import pandas as pd
    from tqdm import tqdm
    import argparse
    
    parser = argparse.ArgumentParser(description='Script to run the blue and fin whale vocalization detector')
    parser.add_argument('output_fn', type=str, help='filename for the csv containing predictions')
    parser.add_argument('files', type=str, nargs='+', help='list of files to process')
    parser.add_argument('--folder', type=str, help='source folder prepended to each filename (include a trailing slash)', default='')
    parser.add_argument('--chunklen', type=int, help='size in seconds for chunking the sound files (predictions are given per chunk)', default=20)
    parser.add_argument('--batchsize', type=int, help='number of samples to process simultaneously (tune according to available memory space)', default=32)
    args = parser.parse_args()
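
    # Example invocation (hypothetical filenames, for illustration only):
    #   python run_model.py predictions.csv rec1.wav rec2.wav --folder recordings/ --chunklen 20 --batchsize 32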
    
    labels = ['Bp_20Plus', 'Bm_D', 'Bp-Downsweep', 'Bp_20Hz', 'Bm_Ant-A', 'Bm_Ant-Z', 'Bm_Ant-B']
    
    
    def run(files, folder, chunklen, batch_size):
        """ Runs the CNN model over a list of sound files, and returns a table with predictions
        
        Parameters
        ----------
        files : list
            list of strings with the names of the sound files to process
        folder : str
            path to load the files from
        chunklen : float
            chunk size in seconds used to split the sound files; a trailing incomplete chunk won't be processed
        batch_size : int
            number of samples to process simultaneously (tune according to available memory space)
    
        Returns
        -------
        out
            Pandas DataFrame containing the predictions
        """
        # prepare the model (weights are loaded onto whichever device is available)
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        model = torch.nn.Sequential(frontend_logstft, resnet)
        model.load_state_dict(torch.load('resnet_weights.stdc', map_location=device))
        model.eval()
        model.to(device)
        
        # run the model over each batch and store predictions
        fns, offsets, preds = [], [], []
        loader = torch.utils.data.DataLoader(Dataset(files, folder, chunklen=chunklen), batch_size=batch_size,
                                             collate_fn=collate_fn, num_workers=8, prefetch_factor=4)
        with torch.no_grad():
            for batch in tqdm(loader):
                if batch is None:  # every sample in this batch failed to load
                    continue
                x, meta = batch
                pred = model(x.to(device))
                fns.extend(meta['fn'])
                offsets.extend(meta['offset'].numpy())
                preds.extend(pred.view(len(x), len(labels)).cpu().detach().numpy())
    
        # prepare the output table, then return it
        out = pd.DataFrame(columns=['fn', 'offset']+['pred_'+l for l in labels])
        out.fn, out.offset = fns, offsets
        for l, p in zip(labels, np.array(preds).T):
            out['pred_'+l] = special.expit(p) # expit == sigmoid (normalise predictions between 0 and 1)
        return out
    
    
    
    class Dataset(torch.utils.data.Dataset):
        """
        Class used to load each chunk of signal into batches
        
        Attributes
        ----------
        fns : list
            filenames to process
        folder : str
            path to load the files from
        chunklen : float
            chunk size in seconds used to split the sound files; a trailing incomplete chunk won't be processed
        """
        def __init__(self, fns, folder, chunklen):
            super().__init__()
            print('dataset initialisation ...', end='')
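            # build one sample per (file, offset) pair; offsets advance by chunklen, and the +.01
            # tolerance keeps a last chunk whose end falls exactly on the file's duration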
            self.samples = np.concatenate([[{'fn':fn, 'offset':offset} for offset in np.arange(0, sf.info(folder+fn).duration-chunklen+.01, chunklen)] for fn in fns])
            print('...done!')
            self.fs, self.folder, self.chunklen = 250, folder, chunklen
    
        def __len__(self):
            return len(self.samples)
    
        def __getitem__(self, idx):
            sample = self.samples[idx]
            fs = sf.info(self.folder+sample['fn']).samplerate
            try:
                sig, fs = sf.read(self.folder+sample['fn'], start=int(sample['offset']*fs), stop=int((sample['offset']+self.chunklen)*fs))
            except Exception:
                print('failed loading '+sample['fn'])
                return None  # dropped later by collate_fn
            sig = sig[:,0] if sig.ndim > 1 else sig  # keep only the first channel
            if fs != self.fs:
                sig = signal.resample(sig, int(self.chunklen*self.fs))  # resample to the model's 250 Hz
            return torch.tensor(norm(sig)).float(), sample
    
    
    def norm(arr):
        # standardise to zero mean and unit variance; the epsilon avoids division by zero on silent chunks
        return (arr - np.mean(arr)) / (np.std(arr) + 1e-10)
    
    
    def collate_fn(batch):
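        # drop samples that failed to load (returned as None), then collate the remainder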
        batch = list(filter(lambda x: x is not None, batch))
        return torch.utils.data.dataloader.default_collate(batch) if len(batch) > 0 else None
    
    
    # the __main__ guard keeps DataLoader worker processes from re-running the script on import
    if __name__ == '__main__':
        out = run(args.files, args.folder, args.chunklen, args.batchsize)
        out.to_csv(args.output_fn, index=False)
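
    # A minimal sketch of downstream use, assuming a hypothetical detection threshold of 0.5:
    #   df = pd.read_csv('predictions.csv')
    #   d_calls = df[df['pred_Bm_D'] > 0.5]  # chunks likely to contain a blue whale D-call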