From ba7550d51e87758953b68c086afd66ce21923770 Mon Sep 17 00:00:00 2001 From: Paul Best <paul.best@lis-lab.fr> Date: Fri, 14 Jan 2022 16:58:23 +0100 Subject: [PATCH] Update run_CNN_HB.py --- run_CNN_HB.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run_CNN_HB.py b/run_CNN_HB.py index be11e68..3bb7fc2 100644 --- a/run_CNN_HB.py +++ b/run_CNN_HB.py @@ -19,7 +19,7 @@ def collate_fn(batch): batch = list(filter(lambda x: x is not None, batch)) return data.dataloader.default_collate(batch) if len(batch) > 0 else None -def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch_size=32): +def run(files, stdcfile, model, folder, pool=False, lensample=5, batch_size=32): model.load_state_dict(load(stdcfile)) model.eval() cuda0 = device('cuda' if cuda.is_available() else 'cpu') @@ -28,7 +28,7 @@ def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch out = pd.DataFrame(columns=['fn', 'offset', 'pred']) fns, offsets, preds = [], [], [] with no_grad(): - for x, meta in tqdm(data.DataLoader(Dataset(files, folder, fe=fe, lensample=lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8,prefetch_factor=4)): + for x, meta in tqdm(data.DataLoader(Dataset(files, folder, lensample=lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8,prefetch_factor=4)): x = x.to(cuda0, non_blocking=True) pred = model(x) temp = pd.DataFrame().from_dict(meta) -- GitLab