diff --git a/gsrp_tdoa_hyperres.py b/gsrp_tdoa_hyperres.py index 5e51a559f4e4137349d2e5cc73d894835dfde4ea..7a6318e27dc755986f5291178f0fe30cc0247128 100755 --- a/gsrp_tdoa_hyperres.py +++ b/gsrp_tdoa_hyperres.py @@ -105,7 +105,7 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve from sklearn.pipeline import Pipeline from sklearn.preprocessing import PolynomialFeatures - tdoas2 = np.zeros((len(pos), num_channel_pairs + 2), np.float32) + tdoas2 = np.zeros((len(pos), num_channel_pairs + 2 + (num_channels-1)**2), np.float32) poly = PolynomialFeatures(2) lin = LinearRegression() pipe = Pipeline([('poly', poly), ('lin', lin)]) @@ -124,7 +124,8 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve der[ind] = coef[num_channels:] poly_min = np.linalg.lstsq(der + der.T, -coef[1:num_channels], rcond=None)[0] with np.errstate(divide='ignore', invalid='ignore'): - return np.log10(pipe.predict(poly_min[np.newaxis]).item()), mat @ (poly_min + mean) + return np.log10(pipe.predict(poly_min[np.newaxis]).item()), mat @ (poly_min + mean),\ + 1/np.sqrt(abs(der + der.T)/1e-6).ravel() cc = np.empty((num_channel_pairs, dw_size), np.float32) for i in trange(len(pos)): @@ -152,7 +153,8 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve if hyper: with np.errstate(divide='ignore'): - tdoas2[i, :2], tdoas2[i, 2:] = _hyperres(tdoas[i, 2:], cc) + tdoas2[i, :2], tdoas2[i, 2:num_channel_pairs + 2], tdoas2[i, num_channel_pairs + 2:] =\ + _hyperres(tdoas[i, 2:], cc) tdoas2[i, 1] += maxs tdoas[:, :2] *= 20 if mode == 'smart': @@ -160,7 +162,8 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve f'{count} out of {len(pos)} TDOA have been fully computed{BColors.ENDC}') if hyper: tdoas2[:, :2] *= 20 - return np.hstack((np.expand_dims(pos, -1), tdoas)), np.hstack((np.expand_dims(pos, -1), tdoas2)) + return np.hstack((np.expand_dims(pos, -1), tdoas)), np.hstack((np.expand_dims(pos, -1), tdoas2)),\ + 'err_norm' + poly.get_feature_names_out([f't0{i}' for i in range(1, num_channels)]) else: return np.hstack((np.expand_dims(pos, -1), tdoas)) @@ -227,9 +230,9 @@ def main(args): if args.no_hyperres: result1 = results else: - result1, result2 = results + result1, result2, err_names = results result2[:, 0] /= sr - result2[:, 3:] /= sr if args.temporal else sr/args.decimate + result2[:, 3:-len(err_names)] /= sr if args.temporal else sr/args.decimate result1[:, 0] /= sr result1[:, 3:] /= sr if args.temporal else sr/args.decimate columns = ','.join(['pos', 'db_norm', 'db'] + [f't{i}{j}'for i, j in combinations(range(sound.shape[1]), 2)]) @@ -245,16 +248,16 @@ def main(args): if args.no_hyperres: df = DataFrame(result1, columns=columns) elif args.wide: - columns = columns + ['h_' + c for c in columns[1:]] + columns = columns + ['h_' + c for c in columns[1:]] + err_names df = DataFrame(np.concatenate([result1, result2[:, 1:]], axis=1), columns=columns) else: if ext in ('xls', 'xlsx', 'ods'): from pandas import ExcelWriter with ExcelWriter(args.outfile) as writer: DataFrame(result1, columns=columns).to_excel(writer, sheet_name='Normal') - DataFrame(result2, columns=columns).to_excel(writer, sheet_name='Hyperres') + DataFrame(result2, columns=columns + err_names).to_excel(writer, sheet_name='Hyperres') return 0 - columns = [(h, c) for h in ('normal', 'hyperres') for c in columns[1:]] + columns = [(h, c) for h in ('normal', 'hyperres') for c in columns[1:]] + [('hyperres', e) for e in err_names] df = DataFrame(np.concatenate([result1[:, 1:], result2[:, 1:]], axis=1), columns=MultiIndex.from_tuples(columns), index=result1[:, 0]) if ext in ('h5', 'hdf'): @@ -268,11 +271,11 @@ def main(args): np.savetxt(args.outfile, result1, delimiter=',', header=columns) elif args.wide: np.savetxt(args.outfile, np.concatenate([result1, result2[:, 1:]], axis=1), delimiter=',', - header=',h_'.join([columns] + columns.split(',')[1:])) + header=',h_'.join([columns] + columns.split(',')[1:])) + err_names else: np.savetxt(args.outfile, np.concatenate([result1, result2[:, 1:]], axis=1), delimiter=',', - header=','.join([' '] + (result1.shape[1]-1)*['normal'] + (result1.shape[1]-1)*['hyperres']) + - '\n' + ',' + columns[4:] + ',' + columns[4:], comments='') + header=','.join([' '] + (result1.shape[1]-1)*['normal'] + (result1.shape[1]-1 + len(err_names))*['hyperres']) + + '\n' + ',' + columns[4:] + ',' + columns[4:] + err_names, comments='') print("Done.") return 0