Skip to content
Snippets Groups Projects
Commit ba7550d5 authored by Paul Best's avatar Paul Best
Browse files

Update run_CNN_HB.py

parent 1377bc0e
No related branches found
No related tags found
No related merge requests found
...@@ -19,7 +19,7 @@ def collate_fn(batch): ...@@ -19,7 +19,7 @@ def collate_fn(batch):
batch = list(filter(lambda x: x is not None, batch)) batch = list(filter(lambda x: x is not None, batch))
return data.dataloader.default_collate(batch) if len(batch) > 0 else None return data.dataloader.default_collate(batch) if len(batch) > 0 else None
def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch_size=32): def run(files, stdcfile, model, folder, pool=False, lensample=5, batch_size=32):
model.load_state_dict(load(stdcfile)) model.load_state_dict(load(stdcfile))
model.eval() model.eval()
cuda0 = device('cuda' if cuda.is_available() else 'cpu') cuda0 = device('cuda' if cuda.is_available() else 'cpu')
...@@ -28,7 +28,7 @@ def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch ...@@ -28,7 +28,7 @@ def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch
out = pd.DataFrame(columns=['fn', 'offset', 'pred']) out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
fns, offsets, preds = [], [], [] fns, offsets, preds = [], [], []
with no_grad(): with no_grad():
for x, meta in tqdm(data.DataLoader(Dataset(files, folder, fe=fe, lensample=lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8,prefetch_factor=4)): for x, meta in tqdm(data.DataLoader(Dataset(files, folder, lensample=lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8,prefetch_factor=4)):
x = x.to(cuda0, non_blocking=True) x = x.to(cuda0, non_blocking=True)
pred = model(x) pred = model(x)
temp = pd.DataFrame().from_dict(meta) temp = pd.DataFrame().from_dict(meta)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment