Skip to content
Snippets Groups Projects
Commit 941531f8 authored by ferrari's avatar ferrari
Browse files

Added error estimation

parent e2d96a7b
No related branches found
No related tags found
No related merge requests found
......@@ -105,7 +105,7 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
tdoas2 = np.zeros((len(pos), num_channel_pairs + 2), np.float32)
tdoas2 = np.zeros((len(pos), num_channel_pairs + 2 + (num_channels-1)**2), np.float32)
poly = PolynomialFeatures(2)
lin = LinearRegression()
pipe = Pipeline([('poly', poly), ('lin', lin)])
......@@ -124,7 +124,8 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve
der[ind] = coef[num_channels:]
poly_min = np.linalg.lstsq(der + der.T, -coef[1:num_channels], rcond=None)[0]
with np.errstate(divide='ignore', invalid='ignore'):
return np.log10(pipe.predict(poly_min[np.newaxis]).item()), mat @ (poly_min + mean)
return np.log10(pipe.predict(poly_min[np.newaxis]).item()), mat @ (poly_min + mean),\
1/np.sqrt(abs(der + der.T)/1e-6).ravel()
cc = np.empty((num_channel_pairs, dw_size), np.float32)
for i in trange(len(pos)):
......@@ -152,7 +153,8 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve
if hyper:
with np.errstate(divide='ignore'):
tdoas2[i, :2], tdoas2[i, 2:] = _hyperres(tdoas[i, 2:], cc)
tdoas2[i, :2], tdoas2[i, 2:num_channel_pairs + 2], tdoas2[i, num_channel_pairs + 2:] =\
_hyperres(tdoas[i, 2:], cc)
tdoas2[i, 1] += maxs
tdoas[:, :2] *= 20
if mode == 'smart':
......@@ -160,7 +162,8 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True, ve
f'{count} out of {len(pos)} TDOA have been fully computed{BColors.ENDC}')
if hyper:
tdoas2[:, :2] *= 20
return np.hstack((np.expand_dims(pos, -1), tdoas)), np.hstack((np.expand_dims(pos, -1), tdoas2))
return np.hstack((np.expand_dims(pos, -1), tdoas)), np.hstack((np.expand_dims(pos, -1), tdoas2)),\
'err_norm' + poly.get_feature_names_out([f't0{i}' for i in range(1, num_channels)])
else:
return np.hstack((np.expand_dims(pos, -1), tdoas))
......@@ -227,9 +230,9 @@ def main(args):
if args.no_hyperres:
result1 = results
else:
result1, result2 = results
result1, result2, err_names = results
result2[:, 0] /= sr
result2[:, 3:] /= sr if args.temporal else sr/args.decimate
result2[:, 3:-len(err_names)] /= sr if args.temporal else sr/args.decimate
result1[:, 0] /= sr
result1[:, 3:] /= sr if args.temporal else sr/args.decimate
columns = ','.join(['pos', 'db_norm', 'db'] + [f't{i}{j}'for i, j in combinations(range(sound.shape[1]), 2)])
......@@ -245,16 +248,16 @@ def main(args):
if args.no_hyperres:
df = DataFrame(result1, columns=columns)
elif args.wide:
columns = columns + ['h_' + c for c in columns[1:]]
columns = columns + ['h_' + c for c in columns[1:]] + err_names
df = DataFrame(np.concatenate([result1, result2[:, 1:]], axis=1), columns=columns)
else:
if ext in ('xls', 'xlsx', 'ods'):
from pandas import ExcelWriter
with ExcelWriter(args.outfile) as writer:
DataFrame(result1, columns=columns).to_excel(writer, sheet_name='Normal')
DataFrame(result2, columns=columns).to_excel(writer, sheet_name='Hyperres')
DataFrame(result2, columns=columns + err_names).to_excel(writer, sheet_name='Hyperres')
return 0
columns = [(h, c) for h in ('normal', 'hyperres') for c in columns[1:]]
columns = [(h, c) for h in ('normal', 'hyperres') for c in columns[1:]] + [('hyperres', e) for e in err_names]
df = DataFrame(np.concatenate([result1[:, 1:], result2[:, 1:]], axis=1),
columns=MultiIndex.from_tuples(columns), index=result1[:, 0])
if ext in ('h5', 'hdf'):
......@@ -268,11 +271,11 @@ def main(args):
np.savetxt(args.outfile, result1, delimiter=',', header=columns)
elif args.wide:
np.savetxt(args.outfile, np.concatenate([result1, result2[:, 1:]], axis=1), delimiter=',',
header=',h_'.join([columns] + columns.split(',')[1:]))
header=',h_'.join([columns] + columns.split(',')[1:])) + err_names
else:
np.savetxt(args.outfile, np.concatenate([result1, result2[:, 1:]], axis=1), delimiter=',',
header=','.join([' '] + (result1.shape[1]-1)*['normal'] + (result1.shape[1]-1)*['hyperres']) +
'\n' + ',' + columns[4:] + ',' + columns[4:], comments='')
header=','.join([' '] + (result1.shape[1]-1)*['normal'] + (result1.shape[1]-1 + len(err_names))*['hyperres']) +
'\n' + ',' + columns[4:] + ',' + columns[4:] + err_names, comments='')
print("Done.")
return 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment