From 2ee1b3f38f66a9ad5746eebb6552e279427594cf Mon Sep 17 00:00:00 2001
From: lamipaul <paulobest25@gmail.com>
Date: Fri, 26 Aug 2022 14:23:56 +0200
Subject: [PATCH] add channel option

---
 run_CNN.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/run_CNN.py b/run_CNN.py
index 4a4e071..d86a610 100644
--- a/run_CNN.py
+++ b/run_CNN.py
@@ -13,6 +13,7 @@ parser.add_argument('audio_folder', type=str, help='Path of the folder with audi
 parser.add_argument('specie', type=str, help='Target specie to detect', choices=['megaptera', 'delphinid', 'orcinus', 'physeter', 'balaenoptera'])
 parser.add_argument('-lensample', type=float, help='Length of the signal for each sample (in seconds)', default=5),
 parser.add_argument('-batch_size', type=int, help='Amount of samples to process at a time (usefull for parallel computation using a GPU)', default=32),
+parser.add_argument('-channel', type=int, help='Channel of the audio file to use in the model inference (starting from 0)', default=0)
 parser.add_argument('-maxPool', help='Wether to keep only the maximal prediction of each sample or the full sequence', action='store_true'),
 parser.add_argument('-no-maxPool', dest='maxPool', action='store_false')
 parser.set_defaults(maxPool=True)
@@ -37,6 +38,7 @@ class Dataset(torch.utils.data.Dataset):
                 self.samples.extend([{'fn':fn, 'offset':offset, 'fs':fs} for offset in np.arange(0, duration+.01-lensample, lensample)])
             except:
                 continue
+            assert info.channels > args.channel, f"The desired channel is unavailable for the audio file {fn}"
     def __len__(self):
         return len(self.samples)
 
@@ -47,7 +49,7 @@ class Dataset(torch.utils.data.Dataset):
         except:
             print('Failed loading '+sample['fn'])
             return None
-        sig = sig[:,0]
+        sig = sig[:, args.channel]
         if fs != self.fs:
             sig = signal.resample(sig, self.lensample*self.fs)
         sig = norm(sig)
-- 
GitLab