Skip to content
Snippets Groups Projects
Commit a611e8f9 authored by Paul Best's avatar Paul Best
Browse files

change argument managment

parent 2ee1b3f3
No related branches found
No related tags found
No related merge requests found
...@@ -27,15 +27,14 @@ norm = lambda arr: (arr - np.mean(arr) ) / np.std(arr) ...@@ -27,15 +27,14 @@ norm = lambda arr: (arr - np.mean(arr) ) / np.std(arr)
# Pytorch dataset class to load audio samples # Pytorch dataset class to load audio samples
class Dataset(torch.utils.data.Dataset): class Dataset(torch.utils.data.Dataset):
def __init__(self, folder, fs, lensample): def __init__(self):
super(Dataset, self) super(Dataset, self)
self.fs, self.folder, self.lensample = fs, folder, lensample
self.samples = [] self.samples = []
for fn in tqdm(os.listdir(folder), desc='Dataset initialization', leave=False): for fn in tqdm(os.listdir(args.audio_folder), desc='Dataset initialization', leave=False):
try: try:
info = sf.info(folder+fn) info = sf.info(os.path.join(args.audio_folder, fn))
duration, fs = info.duration, info.samplerate duration, fs = info.duration, info.samplerate
self.samples.extend([{'fn':fn, 'offset':offset, 'fs':fs} for offset in np.arange(0, duration+.01-lensample, lensample)]) self.samples.extend([{'fn':fn, 'offset':offset, 'fs':fs} for offset in np.arange(0, duration+.01 - args.lensample, args.lensample)])
except: except:
continue continue
assert info.channels > args.channel, f"The desired channel is unavailable for the audio file {fn}" assert info.channels > args.channel, f"The desired channel is unavailable for the audio file {fn}"
...@@ -45,29 +44,29 @@ class Dataset(torch.utils.data.Dataset): ...@@ -45,29 +44,29 @@ class Dataset(torch.utils.data.Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
sample = self.samples[idx] sample = self.samples[idx]
try: try:
sig, fs = sf.read(self.folder+sample['fn'], start=int(sample['offset']*sample['fs']), stop=int((sample['offset']+self.lensample)*sample['fs']), always_2d=True) sig, fs = sf.read(os.path.join(args.audio_folder, sample['fn']), start=int(sample['offset']*sample['fs']), stop=int((sample['offset']+args.lensample)*sample['fs']), always_2d=True)
except: except:
print('Failed loading '+sample['fn']) print('Failed loading '+sample['fn'])
return None return None
sig = sig[:, args.channel] sig = sig[:, args.channel]
if fs != self.fs: if fs != models.get[args.specie]['fs']:
sig = signal.resample(sig, self.lensample*self.fs) sig = signal.resample(sig, args.lensample * models.get[args.specie]['fs'])
sig = norm(sig) sig = norm(sig)
return torch.tensor(sig).float(), sample return torch.tensor(sig).float(), sample
# prepare model # prepare model
model = models.get[args.specie]['archi'] model = models.get[args.specie]['archi']
model.load_state_dict(torch.load(f"weights/{models.get[args.specie]['weights']}")) model.load_state_dict(torch.load(f"{os.path.dirname(__file__)}/weights/{models.get[args.specie]['weights']}"))
model.eval() model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device) model.to(device)
# prepare data loader and output storage for predictions # prepare data loader and output storage for predictions
loader = torch.utils.data.DataLoader(Dataset(args.audio_folder, models.get[args.specie]['fs'], args.lensample), loader = torch.utils.data.DataLoader(Dataset(),
batch_size=args.batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4) batch_size=args.batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4)
if len(loader) == 0: if len(loader) == 0:
print(f'Unable to open any audio file in the given folder {args.audiofolder}') print(f'Unable to open any audio file in the given folder {args.audio_folder}')
exit() exit()
out = pd.DataFrame(columns=['filename', 'offset', 'prediction']) out = pd.DataFrame(columns=['filename', 'offset', 'prediction'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment