Skip to content
Snippets Groups Projects
Commit cdd646f5 authored by ferrari's avatar ferrari
Browse files

Improved speed and memory consumption of smart mode

parent 4dee8c32
No related branches found
No related tags found
No related merge requests found
......@@ -85,11 +85,29 @@ def dep_tdoa(tij, nind, ntot):
i0 = j0 + 1
def num_ind(i, j, n_ind):
def num_ind_dep(i, j, n_ind):
    """Return the flat index associated with the index pair ``(i, j)``.

    The result is ``n_ind + j*(n_ind - 1) - j*(j + 1)//2 + (i - 1)``.
    Callers in this file always pass the larger of the two indices as ``i``
    and use the result as a row index into ``cc``.

    Example with ``n_ind = 4``: (1, 0) -> 4, (2, 0) -> 5, (3, 0) -> 6,
    (2, 1) -> 7, (3, 1) -> 8, (3, 2) -> 9.

    NOTE(review): presumably this is the row of the dependent TDOA formed by
    the independent TDOAs ``i`` and ``j`` -- confirm against ``dep_tdoa``.
    """
    # Offset of the group of pairs sharing the same ``j``.
    group_offset = j * (n_ind - 1) - (j * (j + 1)) // 2
    # The first ``n_ind`` indices are skipped, hence the leading ``n_ind``.
    return n_ind + group_offset + (i - 1)
def mul(mem1, mem2, cc, t_max, id1, id2, n_ind, mem_limit=np.infty):
def inv_num_ind(k, n_ind):
    """Convert the flat pair index ``k`` back into an index pair ``(i, j)``.

    ``j`` is the truncated smaller root of ``j**2 - (2*n_ind + 1)*j + 2*k = 0``
    and ``i`` is then recovered from ``k = j*n_ind - j*(j + 1)//2 + i - 1``.

    Example with ``n_ind = 4``: 0 -> (1, 0), 3 -> (4, 0), 4 -> (2, 1),
    7 -> (3, 2), 9 -> (4, 3).

    :param k: flat index of the pair (scalar).
    :param n_ind: size parameter of the numbering.
    :return: tuple ``(i, j)``.
    """
    width = 2 * n_ind + 1
    discriminant = width ** 2 - 8 * k
    # int() truncates; the quotient is non-negative for a non-negative k.
    j = int((width - discriminant ** 0.5) / 2)
    i = k - j * n_ind + j * (j + 1) // 2 + 1
    return i, j
def add_all(memory, tree, cc, t_max, n_ind, n_tot):
    """Evaluate every combination of the entries in ``memory`` and keep the best.

    Each value of ``memory`` is a ``(values, tdoas)`` pair and each key is a
    ``(k0, k1)`` tuple; ``tree[k0][k1]`` selects the rows of the candidate
    matrix that the entry fills.  All independent TDOAs are expected to be
    present in ``memory``.

    :param memory: dict of ``(k0, k1) -> (values, tdoas)``.
    :param tree: nested sequence giving the rows written by each key.
    :param cc: 2-D array indexed as ``cc[row, tdoa]``.
    :param t_max: largest absolute TDOA kept.
    :param n_ind: forwarded to ``dep_tdoa``.
    :param n_tot: number of rows of the candidate matrix.
    :return: ``(best summed value, TDOA column achieving it)``.
    :raises ValueError: from ``np.argmax`` when no combination passes the
        ``t_max`` filter.
    """
    entries = list(memory.values())
    # One index grid per entry; together they enumerate the cartesian product.
    grids = np.meshgrid(*(np.arange(len(entry[0])) for entry in entries), copy=False)
    candidates = np.empty((n_tot, grids[0].size), entries[0][1].dtype)
    for ((level, pos), (_, tdoas)), grid in zip(memory.items(), grids):
        rows = tree[level][pos]
        candidates[rows] = np.take(tdoas, grid, axis=1).reshape(-1, grid.size)
    # The return value is ignored: dep_tdoa is relied on to fill the remaining
    # rows of ``candidates`` in place.
    dep_tdoa(candidates, n_ind, n_tot)
    # Drop every column holding at least one TDOA beyond t_max.
    admissible = (np.abs(candidates) <= t_max).all(0)
    candidates = candidates[:, admissible]
    row_ids = np.arange(n_tot)[:, None]
    scores = cc[row_ids, candidates].sum(0)
    best = np.argmax(scores)
    return scores[best], candidates[:, best]
def add(mem1, mem2, cc, t_max, id1, id2, n_ind, mem_limit=np.infty):
# assume len(id2) == 1
if (2 + len(id1)) * mem1[0].size * mem2[0].size * mem1[0].itemsize > mem_limit:
return
......@@ -104,7 +122,7 @@ def mul(mem1, mem2, cc, t_max, id1, id2, n_ind, mem_limit=np.infty):
out_val = mem1[0][idx1] + mem2[0][idx2]
tij_dep = out_tij[:-1] - out_tij[-1:]
tij_dep *= np.array([1 if i > id2 else -1 for i in id1])[:, np.newaxis]
ch_dep = np.array([num_ind(i, id2, n_ind) if i > id2 else num_ind(id2, i, n_ind) for i in id1])
ch_dep = np.array([num_ind_dep(i, id2, n_ind) if i > id2 else num_ind_dep(id2, i, n_ind) for i in id1])
out_val += cc[ch_dep[:, np.newaxis], tij_dep].sum(0)
return out_val, out_tij
......@@ -132,12 +150,19 @@ def constrained_argmax(mem, cc, tij_ind, curr_tij, used_tij, t_max, n_ind):
for u in used_tij:
for i, c in enumerate(curr_tij):
if u < c:
mem_val += cc[num_ind(c, u, n_ind), mem_tij[i] - tij_ind[u]]
mem_val += cc[num_ind_dep(c, u, n_ind), mem_tij[i] - tij_ind[u]]
else:
mem_val += cc[num_ind(u, c, n_ind), tij_ind[u] - mem_tij[i]]
mem_val += cc[num_ind_dep(u, c, n_ind), tij_ind[u] - mem_tij[i]]
return mem_tij[:, np.argmax(mem_val)]
def truncated_argmax(cc, t_max):
    """Return, for each row of ``cc``, the best lag within ``[-t_max, t_max]``.

    Non-negative lags are read from the first columns of ``cc`` and negative
    lags from the last ``t_max`` columns (index ``-k`` holds lag ``-k``).

    The positive side uses ``:t_max + 1`` so that lag ``+t_max`` is a
    candidate, matching the inclusive range used elsewhere in this file
    (``cc[..., :t_max + 1]`` and ``np.abs(tij) <= t_max``).

    :param cc: 2-D array, one row per channel pair.
    :param t_max: largest absolute lag considered (non-negative int).
    :return: 1-D integer array of signed lags, one per row; on a tie between
        the two sides the negative-side lag is returned.
    """
    best_pos = np.argmax(cc[:, :t_max + 1], axis=1)
    if t_max == 0:
        # ``cc[:, -0:]`` would select the whole row; only lag 0 is admissible.
        return best_pos
    best_neg = np.argmax(cc[:, -t_max:], axis=1) - t_max
    rows = np.arange(len(best_neg))
    return np.where(cc[rows, best_neg] < cc[rows, best_pos], best_pos, best_neg)
def _get_mem_size(memory):
return sum(len(o[0].T) for o in memory.values())
......@@ -150,45 +175,76 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False
memory = dict()
with np.errstate(divide='ignore'):
cc = np.log10(cc)
val = cc[:, 0].sum()
out_val = cc[:, 0].sum()
tij = np.zeros(n_tot, int)
out_tij = tij.copy()
tij[:n_ind] = truncated_argmax(cc[:n_ind], t_max)
dep_tdoa(tij, n_ind, n_tot)
if np.all(np.abs(tij) <= t_max):
val = cc[np.arange(n_tot), tij].sum()
if val > out_val:
out_tij = tij.copy()
out_val = val
if out_val < -1: # out_val < -20dB, trying with other channel as independent base
_sort = (np.argsort(np.asarray(np.triu_indices(n_ind + 1, 1)).ravel()) % tij.size)[n_ind:]
# _sort should be precomputed since fixed by n_ind, but is only 40µs for n_ind =4
ch = np.argmin(cc[np.arange(n_tot), tij][_sort].reshape(-1, n_ind).sum(1))
other_ch = np.asarray([i for i in range(n_ind) if i != ch])
dep = [num_ind_dep(i, ch, n_ind) if i > ch else num_ind_dep(ch, i, n_ind) for i in other_ch]
t_dep = truncated_argmax(cc[dep], t_max)
tij[other_ch] = (-1)**(other_ch < ch) * t_dep + tij[ch]
dep_tdoa(tij, n_ind, n_tot)
if np.all(np.abs(tij) <= t_max):
val = cc[np.arange(n_tot), tij].sum()
if val > out_val:
out_tij = tij.copy()
out_val = val # if this point is reach, it might be useful to swap(0, ch)
for i, step in enumerate(program):
# increase dimensions
for k, v in memory.items():
curr_tij = frozenset(tree[k[0]][k[1]])
memory[k] = mask_val(v, out_val - sum(m for tij, m in max_val.items() if curr_tij.isdisjoint(tij)))
if i and np.prod([v[0].size for v in memory.values()], dtype=np.float64) <= 1024: # float to prevent overflow
return *add_all(memory, tree, cc, t_max, n_ind, n_tot), 1
for j, op in enumerate(step):
if op == 'mem':
if i == 0:
memory[(0, j)] = mask_val((np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])),
np.arange(-t_max, t_max + 1, dtype=np.int32)[np.newaxis]), val)
memory[0, j] = mask_val((np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])),
np.arange(-t_max, t_max + 1, dtype=np.int32)[np.newaxis]), out_val)
else:
memory[(i, j)] = mask_val(memory[(op.left, op.right)], val)
memory[i, j] = memory[op.left, op.right]
else: # op == 'mul'
memory[(i, j)] = mul(mask_val(memory[(i-1, op.left)], val), mask_val(memory[(0, op.right)], val),
cc, t_max, tree[i - 1][op.left], tree[0][op.right][0], n_ind,
mem_limit=mem_limit - _get_mem_usage(memory))
if memory[(i, j)] is None: # means that the memory limit has been reach
return val, tij, 0
memory[i, j] = add(memory[i - 1, op.left], memory[0, op.right], cc, t_max, tree[i - 1][op.left],
tree[0][op.right][0], n_ind, mem_limit=mem_limit - _get_mem_usage(memory))
if memory[i, j] is None: # means that the memory limit has been reach
return out_val, out_tij, 0
# find potential maximum
tij[:] = 0
done_tij = set()
max_val = {frozenset(tree[i][j]): memory[i, j][0].max() for j in range(len(step))}
try:
for j in range(len(step)):
curr_tij = tree[i][j]
tij[curr_tij] = constrained_argmax(memory[(i, j)], cc, tij[:n_ind], curr_tij, done_tij, t_max, n_ind)
tij[curr_tij] = constrained_argmax(memory[i, j], cc, tij[:n_ind], curr_tij, done_tij, t_max, n_ind)
done_tij.update(curr_tij)
dep_tdoa(tij, n_ind, n_tot)
val = max(cc[np.arange(n_tot), tij].sum(), val)
val = cc[np.arange(n_tot), tij].sum()
if val > out_val:
out_tij = tij.copy()
out_val = val
except ValueError: # search of potential maxima ended outside of possible values
tij_min, tij_max = memory[(i, j)][1].min(), memory[(i, j)][1].max()
tij_min, tij_max = memory[i, j][1].min(), memory[i, j][1].max()
for k in range(j):
memory[(i, k)] = mask_lim(memory[(i, k)], tij_min, tij_max, t_max)
memory[i, k] = mask_lim(memory[i, k], tij_min, tij_max, t_max)
if verbose:
mem_size = _get_mem_size(memory)
tqdm.write(f'TDOA: {tij}, val: {20*val:7.3f}dB, mem size: {mem_size} items, {_get_mem_usage(memory):3.2e} octets,'
tqdm.write(f'TDOA: {out_tij}, val: {20*out_val:7.3f}dB, mem size: {mem_size} items, '
f'{_get_mem_usage(memory):3.2e} octets, '
f'{100 * mem_size / (n_ind // (i + 1)) / (2 * t_max + 1) ** (i + 1):.4}%')
# Mem clean up
for p in clean_list[i]:
del memory[p]
return val, tij, 1
return out_val, out_tij, 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment