From ba7550d51e87758953b68c086afd66ce21923770 Mon Sep 17 00:00:00 2001
From: Paul Best <paul.best@lis-lab.fr>
Date: Fri, 14 Jan 2022 16:58:23 +0100
Subject: [PATCH] Update run_CNN_HB.py

---
 run_CNN_HB.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/run_CNN_HB.py b/run_CNN_HB.py
index be11e68..3bb7fc2 100644
--- a/run_CNN_HB.py
+++ b/run_CNN_HB.py
@@ -19,7 +19,7 @@ def collate_fn(batch):
     batch = list(filter(lambda x: x is not None, batch))
     return data.dataloader.default_collate(batch) if len(batch) > 0 else None
 
-def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch_size=32):
+def run(files, stdcfile, model, folder, pool=False, lensample=5, batch_size=32):
     model.load_state_dict(load(stdcfile))
     model.eval()
     cuda0 = device('cuda' if cuda.is_available() else 'cpu')
@@ -28,7 +28,7 @@ def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch
     out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
     fns, offsets, preds = [], [], []
     with no_grad():
-        for x, meta in tqdm(data.DataLoader(Dataset(files, folder, fe=fe, lensample=lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8,prefetch_factor=4)):
+        for x, meta in tqdm(data.DataLoader(Dataset(files, folder, lensample=lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8,prefetch_factor=4)):
             x = x.to(cuda0, non_blocking=True)
             pred = model(x)
             temp = pd.DataFrame().from_dict(meta)
-- 
GitLab