diff --git a/run_CNN.py b/run_CNN.py index 4a4e07147782c9f3fb71b775f60b403467ab3b41..d86a6103c7ee7087a2e95a4eab70cc2e397b477c 100644 --- a/run_CNN.py +++ b/run_CNN.py @@ -13,6 +13,7 @@ parser.add_argument('audio_folder', type=str, help='Path of the folder with audi parser.add_argument('specie', type=str, help='Target specie to detect', choices=['megaptera', 'delphinid', 'orcinus', 'physeter', 'balaenoptera']) parser.add_argument('-lensample', type=float, help='Length of the signal for each sample (in seconds)', default=5), parser.add_argument('-batch_size', type=int, help='Amount of samples to process at a time (usefull for parallel computation using a GPU)', default=32), +parser.add_argument('-channel', type=int, help='Channel of the audio file to use in the model inference (starting from 0)', default=0) parser.add_argument('-maxPool', help='Wether to keep only the maximal prediction of each sample or the full sequence', action='store_true'), parser.add_argument('-no-maxPool', dest='maxPool', action='store_false') parser.set_defaults(maxPool=True) @@ -37,6 +38,7 @@ class Dataset(torch.utils.data.Dataset): self.samples.extend([{'fn':fn, 'offset':offset, 'fs':fs} for offset in np.arange(0, duration+.01-lensample, lensample)]) except: continue + assert info.channels > args.channel, f"The desired channel is unavailable for the audio file {fn}" def __len__(self): return len(self.samples) @@ -47,7 +49,7 @@ class Dataset(torch.utils.data.Dataset): except: print('Failed loading '+sample['fn']) return None - sig = sig[:,0] + sig = sig[:, args.channel] if fs != self.fs: sig = signal.resample(sig, self.lensample*self.fs) sig = norm(sig)