Skip to content
Snippets Groups Projects
Commit 73976daf authored by Paul Best's avatar Paul Best
Browse files

add output filename option

parent a071969a
No related branches found
No related tags found
No related merge requests found
...@@ -16,6 +16,7 @@ parser.add_argument('-batch_size', type=int, help='Amount of samples to process ...@@ -16,6 +16,7 @@ parser.add_argument('-batch_size', type=int, help='Amount of samples to process
parser.add_argument('-channel', type=int, help='Channel of the audio file to use in the model inference (starting from 0)', default=0) parser.add_argument('-channel', type=int, help='Channel of the audio file to use in the model inference (starting from 0)', default=0)
parser.add_argument('-maxPool', help='Wether to keep only the maximal prediction of each sample or the full sequence', action='store_true'), parser.add_argument('-maxPool', help='Wether to keep only the maximal prediction of each sample or the full sequence', action='store_true'),
parser.add_argument('-no-maxPool', dest='maxPool', action='store_false') parser.add_argument('-no-maxPool', dest='maxPool', action='store_false')
parser.add_argument('-output_filename', type=str, help='Name of the output file for saving predictions', default='')
parser.set_defaults(maxPool=True) parser.set_defaults(maxPool=True)
args = parser.parse_args() args = parser.parse_args()
...@@ -86,7 +87,7 @@ with torch.no_grad(): ...@@ -86,7 +87,7 @@ with torch.no_grad():
offsets.extend(meta['offset'].numpy()) offsets.extend(meta['offset'].numpy())
out.filename, out.offset, out.prediction = fns, offsets, preds out.filename, out.offset, out.prediction = fns, offsets, preds
pred_fn = list(filter(lambda e: e!='', args.audio_folder.split('/')))[-1] + ('.csv' if args.maxPool else '.pkl') pred_fn = list(filter(lambda e: e!='', args.audio_folder.split('/')))[-1] + ('.csv' if args.maxPool else '.pkl') if args.output_filename == '' else args.output_filename
print(f'Saving results into {pred_fn}') print(f'Saving results into {pred_fn}')
if args.maxPool: if args.maxPool:
out.to_csv(pred_fn, index=False) out.to_csv(pred_fn, index=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment