diff --git a/run_CNN.py b/run_CNN.py index 4b25fe538611f9ac269f48d953d71ae0ec341edf..93e441ed60734ace049ca4549c130ac3af241793 100644 --- a/run_CNN.py +++ b/run_CNN.py @@ -16,6 +16,7 @@ parser.add_argument('-batch_size', type=int, help='Amount of samples to process parser.add_argument('-channel', type=int, help='Channel of the audio file to use in the model inference (starting from 0)', default=0) parser.add_argument('-maxPool', help='Wether to keep only the maximal prediction of each sample or the full sequence', action='store_true'), parser.add_argument('-no-maxPool', dest='maxPool', action='store_false') +parser.add_argument('-output_filename', type=str, help='Name of the output file for saving predictions', default='') parser.set_defaults(maxPool=True) args = parser.parse_args() @@ -86,7 +87,7 @@ with torch.no_grad(): offsets.extend(meta['offset'].numpy()) out.filename, out.offset, out.prediction = fns, offsets, preds -pred_fn = list(filter(lambda e: e!='', args.audio_folder.split('/')))[-1] + ('.csv' if args.maxPool else '.pkl') +pred_fn = list(filter(lambda e: e!='', args.audio_folder.split('/')))[-1] + ('.csv' if args.maxPool else '.pkl') if args.output_filename == '' else args.output_filename print(f'Saving results into {pred_fn}') if args.maxPool: out.to_csv(pred_fn, index=False)