diff --git a/gsrp_smart_util.py b/gsrp_smart_util.py index eb0a8613cf4aea4fe03424f65c3e6e1a9c22e08e..a8b470e1d4dc078d51358465dbbb29c8a47d178e 100644 --- a/gsrp_smart_util.py +++ b/gsrp_smart_util.py @@ -57,7 +57,7 @@ def op_tree(tree): for group in tree[i]: if group[:-1] in tree[i-1]: j1 = tree[i-1].index(group[:-1]) - j2 = tree[0].index(group[-1:]) + j2 = tree[0].index(group[-1:]) program[i].append(Operation('mul', j1, j2, i)) program[i-1][j1].lifetime = i program[0][j2].lifetime = i @@ -85,11 +85,29 @@ def dep_tdoa(tij, nind, ntot): i0 = j0 + 1 -def num_ind(i, j, n_ind): +def num_ind_dep(i, j, n_ind): return j*(n_ind-1) + i-1 - (j*(j+1))//2 + n_ind -def mul(mem1, mem2, cc, t_max, id1, id2, n_ind, mem_limit=np.infty): +def inv_num_ind(k, n_ind): + j = int((2*n_ind + 1 - ((2*n_ind + 1)**2 - 8*k)**0.5)/2) + return k - j*n_ind + j*(j+1)//2 + 1, j + + +def add_all(memory, tree, cc, t_max, n_ind, n_tot): + values = list(memory.values()) + idx = np.meshgrid(*[np.arange(len(v[0])) for v in values], copy=False) + out_tij = np.empty((n_tot, idx[0].size), values[0][1].dtype) + for ((k0, k1), (v, t)), i in zip(memory.items(), idx): # all ind_tdoa should be present in memory + out_tij[tree[k0][k1]] = np.take(t, i, axis=1).reshape(-1, i.size) + dep_tdoa(out_tij, n_ind, n_tot) + out_tij = out_tij[:, (np.abs(out_tij) <= t_max).all(0)] + val = cc[np.arange(n_tot)[:, None], out_tij].sum(0) + t = np.argmax(val) + return val[t], out_tij[:, t] + + +def add(mem1, mem2, cc, t_max, id1, id2, n_ind, mem_limit=np.infty): # assume len(id2) == 1 if (2 + len(id1)) * mem1[0].size * mem2[0].size * mem1[0].itemsize > mem_limit: return @@ -98,13 +116,13 @@ def mul(mem1, mem2, cc, t_max, id1, id2, n_ind, mem_limit=np.infty): out_tij = np.empty((len(id1) + 1, len(idx1)), mem1[1].dtype) np.take(mem1[1], idx1, axis=1, out=out_tij[:-1]) np.take(mem2[1], idx2, axis=1, out=out_tij[-1:]) - mask = reduce_all(np.abs(out_tij[:-1] - out_tij[-1:]) <= t_max) # Faster than numpy for large array + mask = reduce_all(np.abs(out_tij[:-1] - out_tij[-1:]) <= t_max) # Faster than numpy for large array out_tij = np.compress(mask, out_tij, axis=1) idx1, idx2 = idx1[mask], idx2[mask] out_val = mem1[0][idx1] + mem2[0][idx2] tij_dep = out_tij[:-1] - out_tij[-1:] tij_dep *= np.array([1 if i > id2 else -1 for i in id1])[:, np.newaxis] - ch_dep = np.array([num_ind(i, id2, n_ind) if i > id2 else num_ind(id2, i, n_ind) for i in id1]) + ch_dep = np.array([num_ind_dep(i, id2, n_ind) if i > id2 else num_ind_dep(id2, i, n_ind) for i in id1]) out_val += cc[ch_dep[:, np.newaxis], tij_dep].sum(0) return out_val, out_tij @@ -132,12 +150,19 @@ def constrained_argmax(mem, cc, tij_ind, curr_tij, used_tij, t_max, n_ind): for u in used_tij: for i, c in enumerate(curr_tij): if u < c: - mem_val += cc[num_ind(c, u, n_ind), mem_tij[i] - tij_ind[u]] + mem_val += cc[num_ind_dep(c, u, n_ind), mem_tij[i] - tij_ind[u]] else: - mem_val += cc[num_ind(u, c, n_ind), tij_ind[u] - mem_tij[i]] + mem_val += cc[num_ind_dep(u, c, n_ind), tij_ind[u] - mem_tij[i]] return mem_tij[:, np.argmax(mem_val)] +def truncated_argmax(cc, t_max): + argmax_plus = np.argmax(cc[:, :t_max], axis=1) + argmax_minus = np.argmax(cc[:, -t_max:], axis=1) - t_max + x = np.arange(len(argmax_minus)) + return np.where(cc[x, argmax_minus] < cc[x, argmax_plus], argmax_plus, argmax_minus) + + def _get_mem_size(memory): return sum(len(o[0].T) for o in memory.values()) @@ -150,45 +175,76 @@ def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list, verbose=False memory = dict() with np.errstate(divide='ignore'): cc = np.log10(cc) - val = cc[:, 0].sum() + out_val = cc[:, 0].sum() tij = np.zeros(n_tot, int) + out_tij = tij.copy() + tij[:n_ind] = truncated_argmax(cc[:n_ind], t_max) + dep_tdoa(tij, n_ind, n_tot) + if np.all(np.abs(tij) <= t_max): + val = cc[np.arange(n_tot), tij].sum() + if val > out_val: + out_tij = tij.copy() + out_val = val + if out_val < -1: # out_val < -20dB, trying with other channel as independent base + _sort = (np.argsort(np.asarray(np.triu_indices(n_ind + 1, 1)).ravel()) % tij.size)[n_ind:] + # _sort should be precomputed since fixed by n_ind, but is only 40µs for n_ind =4 + ch = np.argmin(cc[np.arange(n_tot), tij][_sort].reshape(-1, n_ind).sum(1)) + other_ch = np.asarray([i for i in range(n_ind) if i != ch]) + dep = [num_ind_dep(i, ch, n_ind) if i > ch else num_ind_dep(ch, i, n_ind) for i in other_ch] + t_dep = truncated_argmax(cc[dep], t_max) + tij[other_ch] = (-1)**(other_ch < ch) * t_dep + tij[ch] + dep_tdoa(tij, n_ind, n_tot) + if np.all(np.abs(tij) <= t_max): + val = cc[np.arange(n_tot), tij].sum() + if val > out_val: + out_tij = tij.copy() + out_val = val # if this point is reach, it might be useful to swap(0, ch) for i, step in enumerate(program): # increase dimensions + for k, v in memory.items(): + curr_tij = frozenset(tree[k[0]][k[1]]) + memory[k] = mask_val(v, out_val - sum(m for tij, m in max_val.items() if curr_tij.isdisjoint(tij))) + if i and np.prod([v[0].size for v in memory.values()], dtype=np.float64) <= 1024: # float to prevent overflow + return *add_all(memory, tree, cc, t_max, n_ind, n_tot), 1 for j, op in enumerate(step): if op == 'mem': if i == 0: - memory[(0, j)] = mask_val((np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])), - np.arange(-t_max, t_max + 1, dtype=np.int32)[np.newaxis]), val) + memory[0, j] = mask_val((np.concatenate((cc[op.right, -t_max:], cc[op.right, :t_max + 1])), + np.arange(-t_max, t_max + 1, dtype=np.int32)[np.newaxis]), out_val) else: - memory[(i, j)] = mask_val(memory[(op.left, op.right)], val) + memory[i, j] = memory[op.left, op.right] else: # op == 'mul' - memory[(i, j)] = mul(mask_val(memory[(i-1, op.left)], val), mask_val(memory[(0, op.right)], val), - cc, t_max, tree[i - 1][op.left], tree[0][op.right][0], n_ind, - mem_limit=mem_limit - _get_mem_usage(memory)) - if memory[(i, j)] is None: # means that the memory limit has been reach - return val, tij, 0 + memory[i, j] = add(memory[i - 1, op.left], memory[0, op.right], cc, t_max, tree[i - 1][op.left], + tree[0][op.right][0], n_ind, mem_limit=mem_limit - _get_mem_usage(memory)) + if memory[i, j] is None: # means that the memory limit has been reach + return out_val, out_tij, 0 # find potential maximum tij[:] = 0 done_tij = set() + max_val = {frozenset(tree[i][j]): memory[i, j][0].max() for j in range(len(step))} try: for j in range(len(step)): curr_tij = tree[i][j] - tij[curr_tij] = constrained_argmax(memory[(i, j)], cc, tij[:n_ind], curr_tij, done_tij, t_max, n_ind) + tij[curr_tij] = constrained_argmax(memory[i, j], cc, tij[:n_ind], curr_tij, done_tij, t_max, n_ind) done_tij.update(curr_tij) dep_tdoa(tij, n_ind, n_tot) - val = max(cc[np.arange(n_tot), tij].sum(), val) + val = cc[np.arange(n_tot), tij].sum() + if val > out_val: + out_tij = tij.copy() + out_val = val except ValueError: # search of potential maxima ended outside of possible values - tij_min, tij_max = memory[(i, j)][1].min(), memory[(i, j)][1].max() + tij_min, tij_max = memory[i, j][1].min(), memory[i, j][1].max() for k in range(j): - memory[(i, k)] = mask_lim(memory[(i, k)], tij_min, tij_max, t_max) + memory[i, k] = mask_lim(memory[i, k], tij_min, tij_max, t_max) if verbose: mem_size = _get_mem_size(memory) - tqdm.write(f'TDOA: {tij}, val: {20*val:7.3f}dB, mem size: {mem_size} items, {_get_mem_usage(memory):3.2e} octets,' - f' {100 * mem_size / (n_ind // (i + 1)) / (2 * t_max + 1) ** (i + 1):.4}%') + tqdm.write(f'TDOA: {out_tij}, val: {20*out_val:7.3f}dB, mem size: {mem_size} items, ' + f'{_get_mem_usage(memory):3.2e} octets, ' + f'{100 * mem_size / (n_ind // (i + 1)) / (2 * t_max + 1) ** (i + 1):.4}%') # Mem clean up for p in clean_list[i]: del memory[p] - return val, tij, 1 + return out_val, out_tij, 1