Commit 32a690f1 authored by Paul Best

first fixes

parent 704877d1
@@ -14,11 +14,13 @@ parser.add_argument('audio_folder', type=str, help='Path of the folder with audi
 parser.add_argument('specie', type=str, help='Target specie to detect', choices=['megaptera', 'delphinid', 'orcinus', 'physeter', 'balaenoptera'])
 parser.add_argument('pred_fn', type=str, help='Filename for the output table containing model predictions')
 parser.add_argument('-lensample', type=float, help='Length of the signal excerpts to process (sec)', default=5),
-parser.add_argument('-batchsize', type=int, help='Amount of samples to process at a time', default=32),
-parser.add_argument('-maxPool', type=bool, help='Wether to keep only the maximal prediction of a sample or the full sequence', default=True),
+parser.add_argument('-batch_size', type=int, help='Amount of samples to process at a time', default=32),
+parser.add_argument('-maxPool', help='Wether to keep only the maximal prediction of a sample or the full sequence', action='store_true'),
+parser.add_argument('-no-maxPool', dest='maxPool', action='store_false')
+parser.set_defaults(maxPool=True)
 args = parser.parse_args()
 meta_model = {
     'delphinid': {
         'stdc':'sparrow_dolphin_train8_pcen_conv2d_noaugm_bs32_lr.005_.stdc',
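Note on the option changes above: the boolean option moves from type=bool (which argparse applies to the raw string, so '-maxPool False' would still parse as True) to a paired store_true/store_false flag with the default restored via set_defaults. A minimal standalone sketch of that pattern, separate from the project's full parser:

import argparse

# Minimal sketch of the paired-flag pattern introduced above
# (standalone example, not the project's parser).
parser = argparse.ArgumentParser()
parser.add_argument('-maxPool', dest='maxPool', action='store_true',
                    help='keep only the maximal prediction of a sample')
parser.add_argument('-no-maxPool', dest='maxPool', action='store_false',
                    help='keep the full prediction sequence')
parser.set_defaults(maxPool=True)

print(parser.parse_args([]).maxPool)               # True (default)
print(parser.parse_args(['-no-maxPool']).maxPool)  # False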
@@ -50,14 +52,16 @@ def run(folder, stdcfile, model, fs, lensample, batch_size, maxPool):
     out = pd.DataFrame(columns=['fn', 'offset', 'pred'])
     fns, offsets, preds = [], [], []
     loader = data.DataLoader(Dataset(folder, fs, lensample), batch_size=batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4)
+    if len(loader) == 0:
+        print('Unable to open any audio file in the given folder')
     with torch.no_grad():
         for x, meta in tqdm(loader):
             x = x.to(device)
             pred = model(x).cpu().detach().numpy()
             if maxPool:
-                pred = np.maximum(pred)
+                pred = pred.max(axis=-1).reshape(len(x))
             else:
-                pred.reshape(len(x), -1)
+                pred = pred.reshape(len(x), -1)
             fns.extend(meta['fn'])
             offsets.extend(meta['offset'].numpy())
             preds.extend(pred)
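The maxPool branch is the key fix here: np.maximum is a two-input elementwise ufunc, so np.maximum(pred) raised, while pred.max(axis=-1) performs the intended reduction over the time axis and yields one score per excerpt. The else branch was previously a no-op, since ndarray.reshape returns a new array rather than modifying pred in place. The new empty-loader check only prints a warning and still falls through to the loop. A small shape sketch, where the 3-D output shape is an assumption for illustration only:

import numpy as np

# Dummy per-frame predictions: 4 excerpts x 1 class x 10 time frames
# (the real model output shape may differ; this only shows the pooling).
pred = np.random.rand(4, 1, 10)

pooled = pred.max(axis=-1).reshape(len(pred))  # one score per excerpt -> (4,)
full = pred.reshape(len(pred), -1)             # full sequence per excerpt -> (4, 10)
print(pooled.shape, full.shape)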
@@ -72,7 +76,7 @@ class Dataset(data.Dataset):
         self.samples = []
         for fn in os.listdir(folder):
             try:
-                duration = sf.info(folder.fn).duration
+                duration = sf.info(folder+fn).duration
             except:
                 print(f'Skipping {fn} (unable to read)')
                 continue
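folder.fn was attribute access on a string and would always raise; folder+fn fixes that, but still relies on the folder path ending with a separator. A separator-safe variant would go through os.path.join, sketched below as an alternative rather than as what the commit does:

import os
import soundfile as sf

def audio_duration(folder, fn):
    # Hypothetical helper, not part of the commit: os.path.join avoids the
    # trailing-slash requirement of the folder+fn concatenation above.
    try:
        return sf.info(os.path.join(folder, fn)).duration
    except Exception:
        print(f'Skipping {fn} (unable to read)')
        return None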
@@ -97,14 +101,13 @@ class Dataset(data.Dataset):
         sig = norm(sig)
         return torch.tensor(sig).float(), sample

 preds = run(args.audio_folder,
             meta_model['stdc'],
             models.get[args.specie],
             meta_model['fs'],
-            batch_size=args.batch_size,
-            lensample=args.lensample,
-            maxPool=args.maxPool
+            args.lensample,
+            args.batch_size,
+            args.maxPool
             )
 preds.to_pickle(args.pred_fn)
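The call site now passes lensample, batch_size and maxPool positionally, matching the order declared by run(folder, stdcfile, model, fs, lensample, batch_size, maxPool). Two lookups still look off: models.get[args.specie] subscripts the dict.get method and would raise a TypeError, and given the meta_model structure shown above ({'delphinid': {'stdc': ...}}), the 'stdc' and 'fs' lookups presumably need to be indexed by the selected species first. A possible shape for such a call, assuming models and meta_model are both dicts keyed by species; this is a sketch, not part of the commit:

# Hypothetical call, assuming meta_model and models are keyed by species;
# not what this commit contains.
preds = run(args.audio_folder,
            meta_model[args.specie]['stdc'],
            models[args.specie],
            meta_model[args.specie]['fs'],
            args.lensample,
            args.batch_size,
            args.maxPool)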