Skip to content
Snippets Groups Projects
Select Git revision
  • 975f019cb2ce9686007dfe92165ba96d278bdb0b
  • master default protected
  • test-error_interval
3 results

gsrp_tdoa_hyperres.py

Blame
  • gsrp_tdoa_hyperres.py 13.07 KiB
    import os
    import sys
    import itertools
    import argparse
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import PolynomialFeatures
    from sklearn.linear_model import LinearRegression
    import numpy as np
    from numpy.fft import rfft, irfft
    import scipy.signal as sg
    import soundfile as sf
    import c_corr
    from gsrp_smart_util import *
    from math import ceil
    from scipy.signal.windows import tukey
    
    try:
        from tqdm import trange
    except ImportError:
        trange = range
    
    
    class BColors:
        """ANSI SGR escape sequences used to colour and style terminal output."""

        # Bright foreground colours.
        HEADER = '\x1b[95m'     # bright magenta
        OKBLUE = '\x1b[94m'     # bright blue
        OKCYAN = '\x1b[96m'     # bright cyan
        OKGREEN = '\x1b[92m'    # bright green
        WARNING = '\x1b[93m'    # bright yellow
        FAIL = '\x1b[91m'       # bright red
        # Reset every attribute back to the terminal default.
        ENDC = '\x1b[0m'
        # Text attributes.
        BOLD = '\x1b[1m'
        UNDERLINE = '\x1b[4m'
    
    
    def intlist(s):
        """Parse a comma-separated string such as ``'0,2,3'`` into a list of ints.

        Used as an argparse ``type`` callable. ``int`` raises ``ValueError`` for
        any field that is not a valid integer (including an empty field).
        """
        return [int(field) for field in s.split(',')]
    
    
    def slicer(down, up, ndim, n):
        """Tile the hyper-cube ``[down, up)**ndim`` with ``n`` chunks per axis.

        The interval ``[down, up)`` is cut at ``n + 1`` evenly spaced integer
        bounds, giving ``n`` consecutive ``slice`` objects. Every combination of
        one slice per axis is then enumerated.

        Returns an object array of shape ``(n ** ndim, ndim)``; each row holds
        the ``ndim`` slices describing one chunk of the hyper-cube.
        """
        # Integer chunk boundaries along a single axis.
        edges = np.linspace(down, up, n + 1).astype(int)
        chunks = np.asarray([slice(edges[k], edges[k + 1]) for k in range(n)])
        # Chunk number along each axis, for every cell of the n**ndim grid.
        grid = np.mgrid[[slice(0, n)] * ndim]
        per_axis = chunks[grid]
        return per_axis.reshape(ndim, -1).T
    
    
    def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True):
        """Estimate, frame by frame, the TDOAs that maximise the product of pairwise cross-correlations.

        For every entry of ``pos`` a Tukey-windowed frame of ``w_size`` samples is
        cross-correlated (via FFT) for every channel pair. The set of lags that is
        consistent across pairs (i.e. generated by ``num_channels - 1`` free lags
        through ``mat``) and maximises the score is then searched by the backend
        selected with ``mode``.

        Parameters
        ----------
        data : 2-D array, samples along axis 0 and channels along axis 1.
        pos : 1-D integer array of frame positions in samples. Because ``data`` is
            zero-padded by ``(w_size - 1) // 2`` on both ends, each frame is
            (approximately) centred on the given original sample index.
        w_size : frame length in samples.
        max_tdoa : largest lag to consider, in samples before decimation.
        decimate : spectral decimation factor; the spectrum is truncated so the
            cross-correlation has ``w_size // decimate`` lags.
        mode : ``'prepare'``, ``'on-the-fly'`` or ``'smart'`` -- which search
            backend is used (``c_corr.c_corr_at``, ``c_corr.c_corr_all`` or
            ``smart_gsrp`` respectively).
        hyper : when true, additionally refine each integer estimate with a
            quadratic fit around it (sub-sample resolution).

        Returns
        -------
        An array of shape ``(len(pos), num_channel_pairs + 3)``: column 0 is the
        position, columns 1-2 are the two scores returned by the backend scaled by
        20 (column 2 additionally includes the log10 of the product of the
        per-pair maxima), and the remaining columns are the lags of every channel
        pair. The CLI help describes the two scores as the normalized and
        unnormalized cross-correlation product in decibel -- presumably in that
        order; confirm against ``c_corr``.
        When ``hyper`` is true a tuple ``(integer_result, refined_result)`` of two
        such arrays is returned instead.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the supported values.
        """
        num_channels = data.shape[1]
        num_channel_pairs = num_channels * (num_channels - 1) // 2

        # Zero-pad the time axis so that the slice data[pos:pos + w_size] taken
        # below is centred on the original sample ``pos``.
        data = np.pad(data, [((w_size - 1) // 2, (w_size - 1) // 2), (0, 0)], 'constant')
        win = tukey(w_size, 0.2)[:, np.newaxis]
        # Width (in decimated samples) of the lag range that is searched.
        cc_size = min(w_size, int(2 * max_tdoa // decimate))
        # v1[k], v2[k]: the two channels of pair k (v1 < v2).
        v1 = np.empty(num_channel_pairs, np.int8)
        v2 = np.empty(num_channel_pairs, np.int8)
        # mat maps the num_channels - 1 free lags (column c <-> channel c + 1) to
        # the lag of every pair: lag(i, j) = t_j - t_i, with channel 0 contributing
        # no column (its lag is implicitly zero).
        mat = np.zeros((num_channel_pairs, num_channels - 1), np.int8)
        for k, (i, j) in enumerate(itertools.combinations(range(num_channels), 2)):
            if i > 0:
                mat[k, i - 1] = -1
            mat[k, j - 1] = 1
            v1[k] = i
            v2[k] = j
        # Number of lags in the (decimated) cross-correlation.
        dw_size = w_size // decimate
        if mode == 'prepare':
            # Enumerate every candidate combination of free lags, chunk by chunk
            # (16 chunks per dimension), convert it to per-pair lags and keep only
            # the combinations whose pair lags all lie within +/- cc_size // 2.
            slices = slicer(-(cc_size // 2), cc_size // 2, (num_channels - 1), 16)
            tausf = []
            for j in range(len(slices)):
                taus = np.mgrid[slices[j]].reshape(num_channels - 1, -1).astype(np.int16)
                taus2 = np.matmul(mat, taus)
                tausf += [taus2[:, np.abs(taus2).max(0) <= cc_size // 2]]
            tausf = np.hstack(tausf)
            # Wrap negative lags to non-negative indices along cc's lag axis.
            tausf %= dw_size
        elif mode == 'on-the-fly':
            # Nothing to precompute: candidates are generated during the search.
            pass
        elif mode == 'smart':
            # Search structures provided by gsrp_smart_util (star import).
            tree = gen_tree(num_channels - 1)
            program, clean_list = op_tree(tree)
        else:
            raise ValueError(f'Unknown mode {mode}')

        # Per frame: two scores followed by the lag of every channel pair.
        tdoas = np.zeros((len(pos), num_channel_pairs + 2), np.float32)

        if hyper:   # prepare hyper res
            tdoas2 = np.zeros((len(pos), num_channel_pairs + 2), np.float32)
            # Degree-2 polynomial regression used to model the score around the
            # integer maximum.
            poly = PolynomialFeatures(2)
            lin = LinearRegression()
            pipe = Pipeline([('poly', poly), ('lin', lin)])
            # Upper-triangular positions receiving the quadratic coefficients.
            ind = np.triu_indices(num_channels - 1)

            def _hyperres(taus, cc):
                """Refine the integer lags ``taus`` with a quadratic fit of the score.

                Returns ``(log10 of the fitted score at the stationary point,
                refined lags of every channel pair)``.
                """
                # Neighbourhood of offsets -2..2 around each free lag.
                taus = taus[:num_channels-1] + np.stack(np.meshgrid(*(num_channels - 1) * (np.arange(-2, 3),)), 0).reshape(
                    num_channels - 1, -1).T
                # Per-pair lags of every neighbour; drop those outside the search range.
                taus = np.matmul(mat, taus.T)
                taus = taus[:, np.abs(taus).max(0) <= cc_size // 2]
                # Centre the regressors on the neighbourhood mean.
                mean = taus.mean(-1)[:num_channels-1]
                # Target: product over pairs of the cross-correlation at each
                # neighbour (negative lags index cc from the end of the lag axis).
                coef = pipe.fit(taus.T[:, :num_channels-1] - mean,
                                cc[np.expand_dims(np.arange(num_channel_pairs), 1), taus.astype(int)].prod(0)
                                ).named_steps['lin'].coef_
                # coef[0] is the bias column, coef[1:num_channels] the linear terms,
                # the rest the quadratic terms -- assumed to follow the same
                # upper-triangular ordering as ``ind`` (confirm against
                # PolynomialFeatures' feature ordering).
                der = np.zeros((num_channels - 1, num_channels - 1))
                der[ind] = coef[num_channels:]
                # Stationary point of the quadratic: (der + der.T) x = -linear terms.
                poly_min = np.linalg.lstsq(der + der.T, -coef[1:num_channels], rcond=None)[0]
                return np.log10(pipe.predict(poly_min[np.newaxis]).item()), mat @ (poly_min + mean)

        # Reused buffer holding the cross-correlation of every pair for one frame.
        cc = np.empty((num_channel_pairs, dw_size), np.float32)
        for i in trange(len(pos)):
            fft = rfft(win * data[pos[i]:w_size + pos[i]], axis=0)
            if decimate > 1:
                # Decimate in the spectral domain by discarding the upper bins.
                fft = fft[:(len(fft) - 1) // decimate + 1]
            fft = np.asarray(fft, dtype=np.complex64)
            # Cross-correlation of channel v2 against channel v1 for every pair.
            cc[:] = irfft(fft[:, v2] * np.conj(fft[:, v1]), axis=0).T
            # Shift each pair to a minimum of 0 and scale it to a maximum of 1.
            # NOTE(review): a constant (e.g. silent) frame gives maxs == 0 and a
            # division by zero here -- confirm whether that can occur upstream.
            cc -= cc.min(-1, keepdims=True)
            maxs = cc.max(1, keepdims=True)
            cc /= maxs
            # log10 of the product of the per-pair maxima, added back to the
            # second score below.
            maxs = np.log10(maxs.prod())
            if mode == 'prepare':
                tdoas[i, :2], index = c_corr.c_corr_at(cc, tausf)
                # Undo the modulo wrapping to recover signed lags.
                tdoas[i, 2:] = ((tausf[:, index] + dw_size // 2) % dw_size) - dw_size // 2
            elif mode == 'on-the-fly':
                tdoas[i, :2], tdoas[i, 2:] = c_corr.c_corr_all(cc, cc_size//2, num_channels - 1)
            elif mode == 'smart':
                tdoas[i, :2], tdoas[i, 2:] = smart_gsrp(cc, num_channels - 1, num_channel_pairs, cc_size // 2,
                                               tree, program, clean_list)
            else:
                # Unreachable in practice: an unknown mode already raised above.
                raise ValueError(f'Unknown mode {mode}')
            tdoas[i, 1] += maxs

            if hyper:
                tdoas2[i, :2], tdoas2[i, 2:] = _hyperres(tdoas[i, 2:], cc)
                tdoas2[i, 1] += maxs
        # Convert the log10 scores with a factor of 20.
        tdoas[:, :2] *= 20
        if hyper:
            tdoas2[:, :2] *= 20
            return np.hstack((np.expand_dims(pos, -1), tdoas)), np.hstack((np.expand_dims(pos, -1), tdoas2))
        else:
            return np.hstack((np.expand_dims(pos, -1), tdoas))
    
    
    def main(args):
        """Run the TDOA estimation described by the parsed command-line arguments.

        Loads (part of) the sound file, optionally selects / inverts channels,
        band-limits and decimates the signal, computes the TDOAs with ``corr`` and
        writes the result(s) to ``args.outfile``.

        Parameters
        ----------
        args : argparse.Namespace
            Must provide ``infile``, ``outfile``, ``erase``, ``start``, ``end``,
            ``channels``, ``inverse``, ``low``, ``up``, ``decimate``, ``temporal``,
            ``stride``, ``frame_size``, ``max_tdoa``, ``mode`` and ``no_hyperres``.
            ``args.channels`` is filled in with every channel index when it is
            ``None``.

        Returns
        -------
        int
            ``0`` on success, ``1`` when ``outfile`` already exists and
            overwriting was not requested.

        Raises
        ------
        ValueError
            If fewer than two channels are selected, or if ``args.stride`` is
            neither an existing file nor a number.
        """
        # ``--erase`` is declared with ``store_false``: args.erase is True unless
        # the flag was given, so a true value here means "do not overwrite".
        if args.erase and os.path.isfile(args.outfile):
            print(f'{BColors.WARNING}{args.outfile} already exist and erase is not set {BColors.ENDC}')
            return 1

        # load audio file
        print(f'Loading {args.infile}...')
        sr = sf.info(args.infile).samplerate
        sound, sr = sf.read(args.infile, start=int(sr * args.start),
                            stop=int(sr * args.end) if args.end is not None else None,
                            dtype=np.float32, always_2d=True)
        if args.channels is not None:
            sound = sound[:, args.channels]
        else:
            args.channels = list(range(sound.shape[1]))
        if sound.shape[1] < 2:
            # Fixed: the attribute is ``channels``; ``args.channel`` raised an
            # AttributeError that masked this error.
            raise ValueError(f'{BColors.FAIL}{args.infile} with channels {args.channels} has not enough channels'
                             f'{BColors.ENDC}')

        if args.inverse is not None:
            # Flip the polarity of the requested channels (indices refer to the
            # channels left after the --channels selection).
            for c in args.inverse:
                sound[:, c] *= -1

        if not (args.low is None and args.up is None):
            print("Filtering...")
            # Cutoffs are normalised by the Nyquist frequency (sr / 2).
            if args.low is not None:
                if args.up is None:
                    sos = sg.butter(3, 2 * args.low / sr, 'highpass', output='sos')
                else:
                    sos = sg.butter(3, [2 * args.low / sr, 2 * args.up / sr], 'bandpass', output='sos')
            else:
                sos = sg.butter(3, 2 * args.up / sr, 'lowpass', output='sos')
            sound = sg.sosfiltfilt(sos, sound, axis=0)

        if args.decimate and args.temporal:
            # Time-domain decimation; otherwise ``corr`` decimates the spectrum.
            sound = sound[::args.decimate]
            sr /= args.decimate

        # Position where the TDOAs are computed
        if os.path.isfile(args.stride):
            # A file of positions in seconds, converted to sample indices.
            pos = (sr * np.loadtxt(args.stride, delimiter=',')).astype(int).ravel()
        else:
            try:
                pos = np.arange(0, len(sound), int(sr * float(args.stride)))
            except ValueError as err:
                raise ValueError(f'Error: hop size {args.stride} is neither an existing file nor a float') from err

        print("Computing TDOAs...")
        results = corr(sound, pos, int(sr * args.frame_size), max_tdoa=int(np.ceil(sr * args.max_tdoa)),
                       decimate=args.decimate if not args.temporal else 1, mode=args.mode, hyper=not args.no_hyperres)
        if args.no_hyperres:
            result1 = results
        else:
            result1, result2 = results

        # The hyper-resolution result goes to a sibling file with ``_2`` inserted
        # before the extension.
        if args.outfile.endswith('.npy'):
            np.save(args.outfile, result1)
            if not args.no_hyperres:
                np.save(args.outfile[:-4] + '_2.npy', result2)
        else:
            np.savetxt(args.outfile, result1, delimiter=',')
            if not args.no_hyperres:
                # splitext only looks at the last path component, so names
                # without an extension or with dots in a directory are handled
                # correctly (rpartition('.') was not).
                root, ext = os.path.splitext(args.outfile)
                np.savetxt(root + '_2' + ext, result2, delimiter=',')
        print("Done.")
        return 0
    
    
    if __name__ == "__main__":
        class SmartFormatter(argparse.HelpFormatter):
            """Help formatter that keeps explicit line breaks for help texts starting with ``R|``.

            Help strings without the ``R|`` prefix are wrapped as usual.
            """

            def _split_lines(self, text, width):
                if text.startswith('R|'):
                    # Wrap every explicit line on its own so the author's line
                    # breaks are preserved while long lines are still folded.
                    return [wrapped
                            for raw_line in text[2:].splitlines()
                            for wrapped in super()._split_lines(raw_line, width)]
                return super()._split_lines(text, width)


        parser = argparse.ArgumentParser(description='Computes TDOA estimates from a multi-channel recording.',
                                         formatter_class=SmartFormatter)

        parser.add_argument('infile', type=str, help='The sound file to process.')
        parser.add_argument('outfile', type=str, help='The text or npy file to write results to. Each row gives the '
                                                      'position (in samples), cross-correlation product in decibel '
                                                      '(normalized and unormalized), the independent TDOAs (in samples), '
                                                      'and TDOAs derived from the independent ones.')

        group1 = parser.add_argument_group('Channels')
        group1.add_argument('-c', '--channels', type=intlist, default=None,
                            help='The channels to cross-correlate. Accepts two or more,  but beware of high memory use. To '
                                 'be given as a comma-separated list of numbers, with 0 referring to the first channel '
                                 '(default: all channels).')
        # Fixed: a space was missing at the string concatenation ("numbers,with").
        group1.add_argument('-i', '--inverse', type=intlist, default=None,
                            help='Inverse the channel. To be given as a comma-separated list of numbers, '
                                 'with 0 referring to the first channel once channels have been chosen by --channels.')

        group2 = parser.add_argument_group('Size settings')
        group2.add_argument('-f', '--frame-size', type=float, default=0.02,
                            help='The size of the cross-correlation frames in seconds  (default: %(default)s)')
        # --stride and --pos share the same destination: the value is either a hop
        # size in seconds or the path of a file listing positions.
        group2_s = group2.add_mutually_exclusive_group()
        group2_s.add_argument('-s', '--stride', type=str, default='0.01', dest='stride',
                              help='The step between the beginnings of sequential frames  in seconds (default: %(default)s)')
        group2_s.add_argument('-p', '--pos', type=str, dest='stride',
                              help='The position in second from csv file path. Not allowed if stride is set')
        group2.add_argument('-m', '--max-tdoa', type=float, default=0.0011,
                            help='The maximum TDOA in seconds (default: %(default)s).')

        group2.add_argument('-S', '--start', metavar='SECONDS', type=float, default=0,
                            help='If given, only analyze from the given position.')
        group2.add_argument('-E', '--end', metavar='SECONDS', type=float, default=None,
                            help='If given, only analyze up to the given position.')

        group3 = parser.add_argument_group('Filtering')
        group3.add_argument('-l', '--low', type=float, default=None, help='Bottom cutoff frequency. Disabled by default.')
        group3.add_argument('-u', '--up', type=float, default=None, help='Top cutoff frequency. Disabled by default.')
        group3.add_argument('-d', '--decimate', type=int, default=1, help='Downsample the signal by the given factor. '
                                                                          'Disabled by default')
        group3.add_argument('-t', '--temporal', action='store_true', help='If given, any decimation will be applied in the '
                                                                          'time domain instead of the spectral domain.')

        group4 = parser.add_argument_group('Other')
        # NOTE: store_false is intentional here -- main() refuses to overwrite
        # while args.erase is still True (i.e. while --erase was NOT given).
        group4.add_argument('-e', '--erase', action='store_false', help='Erase existing outfile. If outfile exist and '
                                                                        '--erase is not provide, the script will exit.')
        group4.add_argument('-n', '--no-hyperres', action='store_true', help='Disable the hyper resolution evalutation of '
                                                                             'the TDOA')
        # A tuple (not a set) keeps the order of the choices stable in the usage
        # and error messages from one run to the next.
        group4.add_argument('-M', '--mode', choices=('prepare', 'on-the-fly', 'smart', 'auto'), default='smart',
                            help=f'R|How to explore the TDOA space (default: %(default)s).\n'
                                 f'{BColors.BOLD}prepare{BColors.ENDC} precomputes all the possible TDOA pairs and then '
                                 f'evaluate them. All the results are save in memory.\n'
                                 f'{BColors.BOLD}on-the-fly{BColors.ENDC} compute the TDOA pairs at the same time as it '
                                 f'compute the loss function. Only the maximum is saved. Can be slower than '
                                 f'{BColors.BOLD}prepare{BColors.ENDC}.\n'
                                 f'{BColors.BOLD}smart{BColors.ENDC} gradually increase the search space dimension, '
                                 f'reducing the number of tdoa to evaluate.\n'
                                 f'{BColors.BOLD}auto{BColors.ENDC} automatically try to pick the right method.')

        args = parser.parse_args()
        try:
            if args.mode in ['auto']:
                raise NotImplementedError(f'mode {args.mode} is not yet implemented')

            sys.exit(main(args))

        except Exception as e:
            # Top-level boundary: report the failure on stderr and exit with a
            # non-zero status (SystemExit from sys.exit above is not caught here).
            print(type(e).__name__, e, sep=': ', file=sys.stderr)
            sys.exit(2)