Skip to content
Snippets Groups Projects
Select Git revision
  • c01089af443c514116ce6036346bdd66c2a33ffb
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.1
11 results

ColorHistoTest.ipynb

Blame
  • run_CNN_HB.py 3.10 KiB
    from model import HB_model
    from scipy import signal
    import soundfile as sf
    from torch import load, no_grad, tensor, device, cuda
    from torch.utils import data
    import numpy as np
    import pandas as pd
    from tqdm import tqdm
    import argparse
    
    parser = argparse.ArgumentParser()
    parser.add_argument('files', type=str, nargs='+')
    parser.add_argument('-outfn', type=str, default='HB_preds.pkl')
    args = parser.parse_args()
    
    stdc = 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc'
    
    def collate_fn(batch):
        batch = list(filter(lambda x: x is not None, batch))
        return data.dataloader.default_collate(batch) if len(batch) > 0 else None
    
    def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch_size=32):
        model.load_state_dict(load(stdcfile))
        model.eval()
        cuda0 = device('cuda' if cuda.is_available() else 'cpu')
        model.to(cuda0)
    
        out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
        fns, offsets, preds = [], [], []
        with no_grad():
            for x, meta in tqdm(data.DataLoader(Dataset(files, folder, fe=fe, lensample=lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8,prefetch_factor=4)):
                x = x.to(cuda0, non_blocking=True)
                pred = model(x)
                temp = pd.DataFrame().from_dict(meta)
                fns.extend(meta['fn'])
                offsets.extend(meta['offset'].numpy())
                preds.extend(pred.reshape(len(x), -1).cpu().detach().numpy())
    #            print(meta, temp, pred.reshape(len(x), -1).shape)
    #            temp['pred'] = pred.reshape(len(x), -1).cpu().detach()
    #            preds = preds.append(temp, ignore_index=True)
        out.fn, out.offset, out.pred = fns, offsets, preds
        #preds.pred = preds.pred.apply(np.array)
        return out
    
    
    
    class Dataset(data.Dataset):
        def __init__(self, fns, folder, fe=11025, lenfile=120, lensample=50): # lenfile and lensample in seconds
            super(Dataset, self)
            print('init dataset')
            self.samples = np.concatenate([[{'fn':fn, 'offset':offset} for offset in np.arange(0, sf.info(folder+fn).duration-lensample+1, lensample)] for fn in fns if sf.info(folder+fn).duration>10])
            self.lensample = lensample
            self.fe, self.folder = fe, folder
    
        def __len__(self):
            return len(self.samples)
    
        def __getitem__(self, idx):
            sample = self.samples[idx]
            fs = sf.info(self.folder+sample['fn']).samplerate
            try:
                sig, fs = sf.read(self.folder+sample['fn'], start=max(0,int(sample['offset']*fs)), stop=int((sample['offset']+self.lensample)*fs))
            except:
                print('failed loading '+sample['fn'])
                return None
            if sig.ndim > 1:
                sig = sig[:,0]
            if len(sig) != fs*self.lensample:
                print('to short file '+sample['fn']+' \n'+str(sig.shape))
                return None
            if fs != self.fe:
                sig = signal.resample(sig, self.lensample*self.fe)
    
            sig = norm(sig)
            return tensor(sig).float(), sample
    
    def norm(arr):
        return (arr - np.mean(arr) ) / np.std(arr)
    
    preds = run(args.files, stdc, HBmodel, './', batch_size=3, lensample=50)
    preds.to_pickle(args.outfn)