diff --git a/run_CNN.py b/run_CNN.py index d86a6103c7ee7087a2e95a4eab70cc2e397b477c..3f5c2b56a35c635a6a49f18058efc64742c1bc4b 100644 --- a/run_CNN.py +++ b/run_CNN.py @@ -27,15 +27,14 @@ norm = lambda arr: (arr - np.mean(arr) ) / np.std(arr) # Pytorch dataset class to load audio samples class Dataset(torch.utils.data.Dataset): - def __init__(self, folder, fs, lensample): + def __init__(self): super(Dataset, self) - self.fs, self.folder, self.lensample = fs, folder, lensample self.samples = [] - for fn in tqdm(os.listdir(folder), desc='Dataset initialization', leave=False): + for fn in tqdm(os.listdir(args.audio_folder), desc='Dataset initialization', leave=False): try: - info = sf.info(folder+fn) + info = sf.info(os.path.join(args.audio_folder, fn)) duration, fs = info.duration, info.samplerate - self.samples.extend([{'fn':fn, 'offset':offset, 'fs':fs} for offset in np.arange(0, duration+.01-lensample, lensample)]) + self.samples.extend([{'fn':fn, 'offset':offset, 'fs':fs} for offset in np.arange(0, duration+.01 - args.lensample, args.lensample)]) except: continue assert info.channels > args.channel, f"The desired channel is unavailable for the audio file {fn}" @@ -45,29 +44,29 @@ class Dataset(torch.utils.data.Dataset): def __getitem__(self, idx): sample = self.samples[idx] try: - sig, fs = sf.read(self.folder+sample['fn'], start=int(sample['offset']*sample['fs']), stop=int((sample['offset']+self.lensample)*sample['fs']), always_2d=True) + sig, fs = sf.read(os.path.join(args.audio_folder, sample['fn']), start=int(sample['offset']*sample['fs']), stop=int((sample['offset']+args.lensample)*sample['fs']), always_2d=True) except: print('Failed loading '+sample['fn']) return None sig = sig[:, args.channel] - if fs != self.fs: - sig = signal.resample(sig, self.lensample*self.fs) + if fs != models.get[args.specie]['fs']: + sig = signal.resample(sig, args.lensample * models.get[args.specie]['fs']) sig = norm(sig) return torch.tensor(sig).float(), sample # prepare model model = models.get[args.specie]['archi'] -model.load_state_dict(torch.load(f"weights/{models.get[args.specie]['weights']}")) +model.load_state_dict(torch.load(f"{os.path.dirname(__file__)}/weights/{models.get[args.specie]['weights']}")) model.eval() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.to(device) # prepare data loader and output storage for predictions -loader = torch.utils.data.DataLoader(Dataset(args.audio_folder, models.get[args.specie]['fs'], args.lensample), +loader = torch.utils.data.DataLoader(Dataset(), batch_size=args.batch_size, collate_fn=collate_fn, num_workers=8, prefetch_factor=4) if len(loader) == 0: - print(f'Unable to open any audio file in the given folder {args.audiofolder}') + print(f'Unable to open any audio file in the given folder {args.audio_folder}') exit() out = pd.DataFrame(columns=['filename', 'offset', 'prediction']) @@ -92,4 +91,4 @@ print(f'Saving results into {pred_fn}') if args.maxPool: out.to_csv(pred_fn, index=False) else: - out.to_pickle(pred_fn) \ No newline at end of file + out.to_pickle(pred_fn)