From 73976daf87baee54fd167b8f7c40ea94637056a2 Mon Sep 17 00:00:00 2001
From: Paul Best <paul.best@lis-lab.fr>
Date: Wed, 28 Sep 2022 12:00:53 +0200
Subject: [PATCH] add output filename option

---
 run_CNN.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/run_CNN.py b/run_CNN.py
index 4b25fe5..93e441e 100644
--- a/run_CNN.py
+++ b/run_CNN.py
@@ -16,6 +16,7 @@ parser.add_argument('-batch_size', type=int, help='Amount of samples to process
 parser.add_argument('-channel', type=int, help='Channel of the audio file to use in the model inference (starting from 0)', default=0)
 parser.add_argument('-maxPool', help='Wether to keep only the maximal prediction of each sample or the full sequence', action='store_true'),
 parser.add_argument('-no-maxPool', dest='maxPool', action='store_false')
+parser.add_argument('-output_filename', type=str, help='Name of the output file for saving predictions', default='')
 parser.set_defaults(maxPool=True)
 args = parser.parse_args()
 
@@ -86,7 +87,7 @@ with torch.no_grad():
         offsets.extend(meta['offset'].numpy())
 
 out.filename, out.offset, out.prediction = fns, offsets, preds
-pred_fn = list(filter(lambda e: e!='', args.audio_folder.split('/')))[-1] + ('.csv' if args.maxPool else '.pkl')
+pred_fn = list(filter(lambda e: e!='', args.audio_folder.split('/')))[-1] + ('.csv' if args.maxPool else '.pkl') if args.output_filename == '' else args.output_filename
 print(f'Saving results into {pred_fn}')
 if args.maxPool:
     out.to_csv(pred_fn, index=False)
-- 
GitLab