From 88adfba6b2391e2efbc7a3f0876ba0dbdab50a0c Mon Sep 17 00:00:00 2001
From: Paul Best <paul.best@lis-lab.fr>
Date: Tue, 30 Nov 2021 11:04:59 +0100
Subject: [PATCH] Upload New File

---
 run_CNN_HB.py | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)
 create mode 100644 run_CNN_HB.py

diff --git a/run_CNN_HB.py b/run_CNN_HB.py
new file mode 100644
index 0000000..be11e68
--- /dev/null
+++ b/run_CNN_HB.py
@@ -0,0 +1,81 @@
+from model import HB_model
+from scipy import signal
+import soundfile as sf
+from torch import load, no_grad, tensor, device, cuda
+from torch.utils import data
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('files', type=str, nargs='+')
+parser.add_argument('-outfn', type=str, default='HB_preds.pkl')
+args = parser.parse_args()
+
+stdc = 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc'
+
+def collate_fn(batch):
+    batch = list(filter(lambda x: x is not None, batch))
+    return data.dataloader.default_collate(batch) if len(batch) > 0 else None
+
+def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch_size=32):
+    model.load_state_dict(load(stdcfile))
+    model.eval()
+    cuda0 = device('cuda' if cuda.is_available() else 'cpu')
+    model.to(cuda0)
+
+    out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
+    fns, offsets, preds = [], [], []
+    with no_grad():
+        for x, meta in tqdm(data.DataLoader(Dataset(files, folder, fe=fe, lensample=lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8,prefetch_factor=4)):
+            x = x.to(cuda0, non_blocking=True)
+            pred = model(x)
+            temp = pd.DataFrame().from_dict(meta)
+            fns.extend(meta['fn'])
+            offsets.extend(meta['offset'].numpy())
+            preds.extend(pred.reshape(len(x), -1).cpu().detach().numpy())
+#            print(meta, temp, pred.reshape(len(x), -1).shape)
+#            temp['pred'] = pred.reshape(len(x), -1).cpu().detach()
+#            preds = preds.append(temp, ignore_index=True)
+    out.fn, out.offset, out.pred = fns, offsets, preds
+    #preds.pred = preds.pred.apply(np.array)
+    return out
+
+
+
+class Dataset(data.Dataset):
+    def __init__(self, fns, folder, fe=11025, lenfile=120, lensample=50): # lenfile and lensample in seconds
+        super(Dataset, self)
+        print('init dataset')
+        self.samples = np.concatenate([[{'fn':fn, 'offset':offset} for offset in np.arange(0, sf.info(folder+fn).duration-lensample+1, lensample)] for fn in fns if sf.info(folder+fn).duration>10])
+        self.lensample = lensample
+        self.fe, self.folder = fe, folder
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, idx):
+        sample = self.samples[idx]
+        fs = sf.info(self.folder+sample['fn']).samplerate
+        try:
+            sig, fs = sf.read(self.folder+sample['fn'], start=max(0,int(sample['offset']*fs)), stop=int((sample['offset']+self.lensample)*fs))
+        except:
+            print('failed loading '+sample['fn'])
+            return None
+        if sig.ndim > 1:
+            sig = sig[:,0]
+        if len(sig) != fs*self.lensample:
+            print('to short file '+sample['fn']+' \n'+str(sig.shape))
+            return None
+        if fs != self.fe:
+            sig = signal.resample(sig, self.lensample*self.fe)
+
+        sig = norm(sig)
+        return tensor(sig).float(), sample
+
+def norm(arr):
+    return (arr - np.mean(arr) ) / np.std(arr)
+
+preds = run(args.files, stdc, HBmodel, './', batch_size=3, lensample=50)
+preds.to_pickle(args.outfn)
-- 
GitLab