Skip to content
Snippets Groups Projects
Commit 88adfba6 authored by Paul Best's avatar Paul Best
Browse files

Upload New File

parent 7908999f
No related branches found
No related tags found
No related merge requests found
from model import HB_model
from scipy import signal
import soundfile as sf
from torch import load, no_grad, tensor, device, cuda
from torch.utils import data
import numpy as np
import pandas as pd
from tqdm import tqdm
import argparse
# Command-line interface: positional `files` (one or more audio file paths,
# relative to the working directory) and optional `-outfn` naming the output
# pickle (default 'HB_preds.pkl').
parser = argparse.ArgumentParser()
parser.add_argument('files', type=str, nargs='+')
parser.add_argument('-outfn', type=str, default='HB_preds.pkl')
args = parser.parse_args()
# Saved model weights (torch state dict) loaded by run() at the bottom of the script.
stdc = 'sparrow_whales_train8C_2610_frontend2_conv1d_noaugm_bs32_lr.05_.stdc'
def collate_fn(batch):
    """Collate a batch while dropping failed (None) samples.

    Dataset.__getitem__ returns None for unreadable or too-short files;
    those entries are filtered out before default collation.  Returns
    None when every sample in the batch failed.
    """
    valid = [item for item in batch if item is not None]
    if not valid:
        return None
    return data.dataloader.default_collate(valid)
def run(files, stdcfile, model, folder, fe=44100, pool=False, lensample=5, batch_size=32):
    """Slide the model over fixed-length windows of each file and collect predictions.

    Parameters
    ----------
    files : list of str -- audio file names, relative to `folder`
    stdcfile : str -- path of the saved state dict loaded into `model`
    model : torch.nn.Module -- network mapping a signal batch to predictions
    folder : str -- directory prefix prepended to every file name
    fe : int -- target sample rate (Hz) forwarded to the Dataset
    pool : bool -- unused; kept for interface compatibility
    lensample : float -- window length in seconds
    batch_size : int -- DataLoader batch size

    Returns
    -------
    pandas.DataFrame with columns 'fn', 'offset', 'pred' (one row per window).
    """
    model.load_state_dict(load(stdcfile))
    model.eval()
    # Run on GPU when available, otherwise fall back to CPU.
    dev = device('cuda' if cuda.is_available() else 'cpu')
    model.to(dev)
    loader = data.DataLoader(
        Dataset(files, folder, fe=fe, lensample=lensample),
        batch_size=batch_size, collate_fn=collate_fn,
        num_workers=8, prefetch_factor=4)
    fns, offsets, preds = [], [], []
    with no_grad():
        for batch in tqdm(loader):
            # collate_fn returns None when every sample in the batch failed
            # to load; unpacking None would raise, so skip it explicitly.
            if batch is None:
                continue
            x, meta = batch
            x = x.to(dev, non_blocking=True)
            pred = model(x)
            fns.extend(meta['fn'])
            offsets.extend(meta['offset'].numpy())
            preds.extend(pred.reshape(len(x), -1).cpu().detach().numpy())
    return pd.DataFrame({'fn': fns, 'offset': offsets, 'pred': preds})
class Dataset(data.Dataset):
    """Non-overlapping fixed-length windows over a set of audio files.

    Each item is a (signal, meta) pair where signal is a float tensor of
    lensample * fe samples (mono, resampled, standardised) and
    meta = {'fn': file name, 'offset': start time in seconds}.
    Unreadable or too-short reads yield None, which collate_fn filters out.
    """

    def __init__(self, fns, folder, fe=11025, lenfile=120, lensample=50):
        # lenfile and lensample are in seconds; lenfile is unused but kept
        # for interface compatibility.
        super().__init__()  # original `super(Dataset, self)` was a no-op expression
        print('init dataset')
        # One sample per non-overlapping lensample-second window; files whose
        # duration is <= 10 s are skipped entirely (hard-coded threshold).
        self.samples = np.concatenate([
            [{'fn': fn, 'offset': offset}
             for offset in np.arange(0, sf.info(folder + fn).duration - lensample + 1, lensample)]
            for fn in fns if sf.info(folder + fn).duration > 10])
        self.lensample = lensample
        self.fe, self.folder = fe, folder

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        fs = sf.info(self.folder + sample['fn']).samplerate
        try:
            sig, fs = sf.read(self.folder + sample['fn'],
                              start=max(0, int(sample['offset'] * fs)),
                              stop=int((sample['offset'] + self.lensample) * fs))
        except Exception:  # narrowed from bare `except:`; best-effort skip is intentional
            print('failed loading ' + sample['fn'])
            return None
        if sig.ndim > 1:  # keep only the first channel of multi-channel audio
            sig = sig[:, 0]
        if len(sig) != fs * self.lensample:  # short read near end of file
            print('too short file ' + sample['fn'] + ' \n' + str(sig.shape))
            return None
        if fs != self.fe:  # resample to the target rate
            sig = signal.resample(sig, self.lensample * self.fe)
        sig = norm(sig)
        return tensor(sig).float(), sample
def norm(arr):
    """Standardise *arr* to zero mean and unit standard deviation."""
    mu = np.mean(arr)
    sigma = np.std(arr)
    return (arr - mu) / sigma
# NOTE(review): the import at the top of the file is `HB_model`; the original
# line referenced the undefined name `HBmodel`, a guaranteed NameError.
# Assuming HB_model is the nn.Module class to instantiate before
# load_state_dict -- confirm against model.py.
preds = run(args.files, stdc, HB_model(), './', batch_size=3, lensample=50)
preds.to_pickle(args.outfn)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment