"""Run a trained CNN humpback-whale (HB) detector over a list of audio files.

Usage:
    python run_CNN_HB.py file1.wav file2.wav ... [-outfn HB_preds.pkl]

Each file is cut into fixed-length windows, batched through the model on
GPU when available, and the per-window predictions are pickled as a
pandas DataFrame with columns ``fn``, ``offset``, ``pred``.
"""
from model import HB_model
from scipy import signal
import soundfile as sf
from torch import load, no_grad, tensor, device, cuda
from torch.utils import data
import numpy as np
import pandas as pd
from tqdm import tqdm
import argparse

# Checkpoint holding the trained weights for HB_model.
stdc = 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc'


def collate_fn(batch):
    """Drop samples that failed to load (None) before default collation.

    Returns None when every sample in the batch failed, so the training
    loop must skip such batches.
    """
    batch = list(filter(lambda x: x is not None, batch))
    return data.dataloader.default_collate(batch) if len(batch) > 0 else None


def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch_size=32):
    """Run batched inference over ``files`` and collect predictions.

    Parameters
    ----------
    files : list of str
        File names, relative to ``folder``.
    stdcfile : str
        Path to the saved ``state_dict`` checkpoint.
    model : torch.nn.Module
        Model instance the checkpoint weights are loaded into.
    folder : str
        Directory prefix for ``files``.
    fe : int
        Target sample rate in Hz (signals are resampled to this).
    pool : bool
        Unused; kept for interface compatibility.
    lensample : float
        Window length in seconds fed to the model.
    batch_size : int
        Number of windows per inference batch.

    Returns
    -------
    pandas.DataFrame with columns ``fn``, ``offset``, ``pred``.
    """
    model.load_state_dict(load(stdcfile))
    model.eval()
    dev = device('cuda' if cuda.is_available() else 'cpu')
    model.to(dev)

    fns, offsets, preds = [], [], []
    loader = data.DataLoader(
        Dataset(files, folder, fe=fe, lensample=lensample),
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=8,
        prefetch_factor=4,
    )
    with no_grad():
        for batch in tqdm(loader):
            if batch is None:
                # collate_fn returns None when every sample failed to load.
                continue
            x, meta = batch
            x = x.to(dev, non_blocking=True)
            pred = model(x)
            fns.extend(meta['fn'])
            offsets.extend(meta['offset'].numpy())
            # One row per window; flatten any per-window prediction vector.
            preds.extend(pred.reshape(len(x), -1).cpu().detach().numpy())

    out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
    out.fn, out.offset, out.pred = fns, offsets, preds
    return out


class Dataset(data.Dataset):
    """Serves fixed-length windows cut from a list of sound files.

    Each item is ``(signal_tensor, {'fn': ..., 'offset': ...})`` where the
    signal is mono, normalised, resampled to ``fe`` Hz, and exactly
    ``lensample`` seconds long. Items that fail to load return ``None``
    and are filtered out by ``collate_fn``.
    """

    def __init__(self, fns, folder, fe=11025, lenfile=120, lensample=50):
        # lenfile and lensample are in seconds; lenfile is currently unused.
        super(Dataset, self).__init__()  # original forgot the () call
        print('init dataset')
        # Non-overlapping window start offsets for every file long enough
        # to yield at least one full window.
        self.samples = np.concatenate([
            [{'fn': fn, 'offset': offset}
             for offset in np.arange(0, sf.info(folder + fn).duration - lensample + 1, lensample)]
            for fn in fns if sf.info(folder + fn).duration >= lensample])
        self.lensample = lensample
        self.fe, self.folder = fe, folder

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        fs = sf.info(self.folder + sample['fn']).samplerate
        try:
            sig, fs = sf.read(self.folder + sample['fn'],
                              start=max(0, int(sample['offset'] * fs)),
                              stop=int((sample['offset'] + self.lensample) * fs))
        except Exception:
            # Unreadable/corrupt file: report and let collate_fn drop it.
            print('failed loading ' + sample['fn'])
            return None
        if sig.ndim > 1:
            sig = sig[:, 0]  # keep first channel only
        if len(sig) != fs * self.lensample:
            print('too short file ' + sample['fn'] + ' \n' + str(sig.shape))
            return None
        if fs != self.fe:
            sig = signal.resample(sig, self.lensample * self.fe)
        sig = norm(sig)
        return tensor(sig).float(), sample


def norm(arr):
    """Standardize to zero mean and unit variance."""
    return (arr - np.mean(arr)) / np.std(arr)


if __name__ == '__main__':
    # Parse args and run only when executed as a script; this also keeps
    # DataLoader worker processes (num_workers > 0) from re-running it.
    parser = argparse.ArgumentParser()
    parser.add_argument('files', type=str, nargs='+')
    parser.add_argument('-outfn', type=str, default='HB_preds.pkl')
    args = parser.parse_args()

    # original referenced undefined `HBmodel`; the imported name is HB_model
    preds = run(args.files, stdc, HB_model, './', batch_size=3, lensample=50)
    preds.to_pickle(args.outfn)