diff --git a/gsrp_tdoa_hyperres.py b/gsrp_tdoa_hyperres.py index c603a56e46eeef1a51d61df0e6873bf55de8893a..3622fd0bc9c0bdd08f26baa5cbfb2daf755fc346 100755 --- a/gsrp_tdoa_hyperres.py +++ b/gsrp_tdoa_hyperres.py @@ -1,8 +1,9 @@ import argparse -import itertools +from itertools import combinations import os import sys +import numpy as np import scipy.signal as sg import soundfile as sf from numpy.fft import rfft, irfft @@ -40,7 +41,7 @@ except ImportError: trange = range -def intlist(s): +def intlist(s: str): return list(map(int, s.split(','))) @@ -75,7 +76,7 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve v1 = np.empty(num_channel_pairs, np.int8) v2 = np.empty(num_channel_pairs, np.int8) mat = np.zeros((num_channel_pairs, num_channels - 1), np.int8) - for k, (i, j) in enumerate(itertools.combinations(range(num_channels), 2)): + for k, (i, j) in enumerate(combinations(range(num_channels), 2)): if i > 0: mat[k, i - 1] = -1 mat[k, j - 1] = 1 @@ -216,16 +217,50 @@ def main(args): result1 = results else: result1, result2 = results - - if args.outfile.endswith('.npy'): + result2[:, 0] /= sr + result2[:, 3:] /= sr if args.temporal else sr/args.decimate + result1[:, 0] /= sr + result1[:, 3:] /= sr if args.temporal else sr/args.decimate + columns = ','.join(['pos', 'db_norm', 'db'] + [f't{i}{j}'for i, j in combinations(range(sound.shape[1]), 2)]) + stem, ext = args.outfile.rsplit('.', 1) + if ext == 'npy': np.save(args.outfile, result1) if not args.no_hyperres: - np.save(args.outfile[:-4] + '_2.npy', result2) - else: - np.savetxt(args.outfile, result1, delimiter=',') - if not args.no_hyperres: - np.savetxt((lambda x1, x2, x3: x1 + '_2' + x2 + x3)(*args.outfile.rpartition('.')), - result2, delimiter=',') + np.save(stem + '_2.' + ext, result2) + elif ext in ('h5', 'hdf', 'pkl', 'xls', 'xlsx', 'ods'): + from pandas import DataFrame, MultiIndex + columns = columns.split(',') + if args.no_hyperres: + df = DataFrame(result1, columns=columns) + elif args.wide: + columns = columns + ['h_' + c for c in columns[1:]] + df = DataFrame(np.concatenate([result1, result2[:, 1:]], axis=1), columns=columns) + else: + if ext in ('xls', 'xlsx', 'ods'): + from pandas import ExcelWriter + with ExcelWriter(args.outfile) as writer: + DataFrame(result1, columns=columns).to_excel(writer, sheet_name='Normal') + DataFrame(result2, columns=columns).to_excel(writer, sheet_name='Hyperres') + return 0 + columns = [(h, c) for h in ('normal', 'hyperres') for c in columns[1:]] + df = DataFrame(np.concatenate([result1[:, 1:], result2[:, 1:]], axis=1), + columns=MultiIndex.from_tuples(columns), index=result1[:, 0]) + if ext in ('h5', 'hdf'): + df.to_hdf(args.outfile, 'df') + elif ext == 'pkl': + df.to_pickle(args.outfile) + else: + df.to_excel(args.outfile) + else: # assuming some kind of csv + if args.no_hyperres: + np.savetxt(args.outfile, result1, delimiter=',', header=columns) + elif args.wide: + np.savetxt(args.outfile, np.concatenate([result1, result2[:, 1:]], axis=1), delimiter=',', + header=',h_'.join([columns] + columns.split(',')[1:])) + else: + np.savetxt(args.outfile, np.concatenate([result1, result2[:, 1:]], axis=1), delimiter=',', + header=','.join([' '] + (result1.shape[1]-1)*['normal'] + (result1.shape[1]-1)*['hyperres']) + + '\n' + ',' + columns[4:] + ',' + columns[4:], comments='') print("Done.") return 0 @@ -263,12 +298,12 @@ if __name__ == "__main__": group2 = parser.add_argument_group('Size settings') group2.add_argument('-f', '--frame-size', type=float, default=0.02, - help='The size of the cross-correlation frames in seconds (default: %(default)s)') + help='The size of the cross-correlation frames in seconds (default: %(default)s)') group2_s = group2.add_mutually_exclusive_group() group2_s.add_argument('-s', '--stride', type=str, default='0.01', dest='stride', - help='The step between the beginnings of sequential frames in seconds (default: %(default)s)') + help='The step between the beginnings of sequential frames in seconds (default: %(default)s)') group2_s.add_argument('-p', '--pos', type=str, dest='stride', - help='The position in second from csv file path. Not allowed if stride is set') + help='The position in second from csv file path. Not allowed if stride is set') group2.add_argument('-m', '--max-tdoa', type=float, default=0.0011, help='The maximum TDOA in seconds (default: %(default)s).') @@ -290,6 +325,9 @@ if __name__ == "__main__": '--erase is not provide, the script will exit.') group4.add_argument('-n', '--no-hyperres', action='store_true', help='Disable the hyper resolution evalutation of ' 'the TDOA') + group4.add_argument('-w', '--wide', action='store_true', + help='Use only one level to concatenate the normal and hyperres results. Behaviour depends on ' + 'the output file type.') group4.add_argument('-M', '--mode', choices={'prepare', 'on-the-fly', 'smart', 'auto'}, default='smart', help=f'R|How to explore the TDOA space (default: %(default)s).\n' f'{BColors.BOLD}prepare{BColors.ENDC} precomputes all the possible TDOA pairs and then '