diff --git a/run_CNN.py b/run_CNN.py index e87bb767f759b8b3e35306387f9db39fcaf03601..e15d1df2ecd14cb2521f8693f659165caf8e506c 100644 --- a/run_CNN.py +++ b/run_CNN.py @@ -89,7 +89,7 @@ with torch.no_grad(): out.filename, out.offset, out.prediction = fns, offsets, preds pred_fn = list(filter(lambda e: e!='', args.audio_folder.split('/')))[-1] + ('.csv' if args.maxPool else '.pkl') if args.output_filename == '' else args.output_filename print(f'Saving results into {pred_fn}') -if args.maxPool: +if pred_fn.endswith('csv'): out.to_csv(pred_fn, index=False) else: out.to_pickle(pred_fn)