Skip to content
Snippets Groups Projects
Commit c1c4b23f authored by ferrari's avatar ferrari
Browse files

Fix for number of channels not equal to 4

parent 17891454
No related branches found
No related tags found
No related merge requests found
......@@ -50,14 +50,16 @@ def gen_tree(size):
def op_tree(tree):
program = [list() for _ in range(len(tree))]
clean_list = [list() for _ in range(len(tree) + 1)]
clean_list = [list() for _ in range(len(tree))]
program[0] = [Operation('mem', 0, group[0], len(tree)) for group in tree[0]]
for i in range(1, len(tree)):
for group in tree[i]:
if group[:-1] in tree[i-1]:
j = tree[i-1].index(group[:-1])
program[i-1][j].lifetime = i
program[i].append(Operation('mul', j, tree[0].index(group[-1:]), i))
j1 = tree[i-1].index(group[:-1])
j2 = tree[0].index(group[-1:])
program[i].append(Operation('mul', j1, j2, i))
program[i-1][j1].lifetime = i
program[0][j2].lifetime = i
else:
for j in range(i-1, -1, -1):
if group in tree[j]:
......
......@@ -92,13 +92,13 @@ def corr(data, pos, w_size, max_tdoa, decimate=1, mode='prepare', hyper=True):
num_channels - 1, -1).T
taus = np.matmul(mat, taus.T)
taus = taus[:, np.abs(taus).max(0) <= cc_size // 2]
mean = taus.mean(-1)[:3]
coef = pipe.fit(taus.T[:, :3] - mean,
mean = taus.mean(-1)[:num_channels-1]
coef = pipe.fit(taus.T[:, :num_channels-1] - mean,
cc[np.expand_dims(np.arange(num_channel_pairs), 1), taus.astype(int)].prod(0)
).named_steps['lin'].coef_
der = np.zeros((num_channels - 1, num_channels - 1))
der[ind] = coef[4:]
poly_min = np.linalg.lstsq(der + der.T, -coef[1:4], rcond=None)[0]
der[ind] = coef[num_channels:]
poly_min = np.linalg.lstsq(der + der.T, -coef[1:num_channels], rcond=None)[0]
return np.log10(pipe.predict(poly_min[np.newaxis]).item()), mat @ (poly_min + mean)
cc = np.empty((num_channel_pairs, dw_size), np.float32)
......@@ -237,9 +237,11 @@ if __name__ == "__main__":
group2 = parser.add_argument_group('Size settings')
group2.add_argument('-f', '--frame-size', type=float, default=0.02,
help='The size of the cross-correlation frames in seconds (default: %(default)s)')
group2.add_argument('-s', '--stride', type=str, default='0.01',
help='The step between the beginnings of sequential frames in seconds (default: %(default)s), '
'or the postion in second if csv file path is given.')
group2_s = group2.add_mutually_exclusive_group()
group2_s.add_argument('-s', '--stride', type=str, default='0.01',
help='The step between the beginnings of sequential frames in seconds (default: %(default)s)')
group2_s.add_argument('-p', '--pos', type=str, default='pos.csv',
help='The position in second from csv file path (default: %(default)s)')
group2.add_argument('-m', '--max-tdoa', type=float, default=0.0011,
help='The maximum TDOA in seconds (default: %(default)s).')
......@@ -280,5 +282,5 @@ if __name__ == "__main__":
sys.exit(main(args))
except Exception as e:
print(e)
print(type(e).__name__, e, sep=': ')
sys.exit(2)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment