import numpy as np
from math import ceil


class Operation:
    """One instruction of the program built by :func:`op_tree`.

    ``op`` is ``'mem'`` or ``'mul'``:

    * ``'mem'`` at step 0: load row ``right`` of ``cc`` (``left`` is 0, unused).
    * ``'mem'`` at step > 0: copy memory entry ``(left, right)``, i.e. the entry
      at position ``right`` created at step ``left``.
    * ``'mul'`` at step ``i``: combine entry ``(i - 1, left)`` with the step-0
      entry ``(0, right)``.

    ``lifetime`` is the last step at which the entry produced by this operation
    is still read; :func:`smart_gsrp` frees it once that step is finished.  A
    lifetime equal to the number of steps means "never freed".

    Comparing an ``Operation`` with a string compares only ``op``, so
    ``op == 'mem'`` can be used as a kind test.
    """

    __slots__ = 'op', 'left', 'right', 'lifetime'

    def __init__(self, op, left, right, lifetime):
        self.op = op
        self.left = left
        self.right = right
        self.lifetime = lifetime

    def __repr__(self):
        return f'(op: {self.op}, l: {self.left}, r: {self.right}, lt: {self.lifetime})'

    def __eq__(self, other):
        if isinstance(other, str):
            return self.op == other
        if isinstance(other, Operation):
            return (self.op == other.op and self.left == other.left
                    and self.right == other.right
                    and self.lifetime == other.lifetime)
        return NotImplemented


def gen_tree(size):
    """Build nested groupings of the indices ``0 .. size - 1``.

    ``tree[i]`` is a list of groups (lists of indices) that together contain
    every index exactly once, each group holding at most ``i + 1`` indices.
    ``tree[-1]`` is the single group ``[0, ..., size - 1]`` and ``tree[0]``
    holds singletons.  Level ``i`` is derived from level ``i + 1``:

    1. the last index of every over-long group is moved to an "unused" pool;
    2. short groups are topped up from that pool, in order;
    3. whatever is left in the pool is packed into new groups.
    """
    tree = [[] for _ in range(size)]
    tree[-1].append(list(range(size)))
    for i in range(size - 2, -1, -1):
        m_len = i + 1
        unused = []
        for group in tree[i + 1]:
            if len(group) > m_len:
                tree[i].append(group[:-1])  # slicing already makes a copy
                unused.append(group[-1])
            else:
                tree[i].append(group.copy())
        for group in tree[i]:
            missing = m_len - len(group)
            if missing > 0:
                group.extend(unused[:missing])
                unused = unused[missing:]
                if not unused:
                    break
        for j in range(ceil(len(unused) / m_len)):
            tree[i].append(unused[j * m_len:(j + 1) * m_len])
    return tree


def op_tree(tree):
    """Compile a tree from :func:`gen_tree` into a step-by-step program.

    Returns ``(program, clean_list)``:

    * ``program[i]`` holds one :class:`Operation` per group of ``tree[i]``, in
      the same order.
    * ``clean_list[s]`` lists the ``(step, position)`` memory keys that are no
      longer read after step ``s`` and can be freed once it finishes.  The last
      slot (index ``len(tree)``) holds the never-freed step-0 entries.

    Raises:
        KeyError: if a group cannot be found on any lower level.
    """
    n_steps = len(tree)
    program = [[] for _ in range(n_steps)]
    # Step-0 entries are the right operand of 'mul' ops at arbitrary later
    # steps, so they get the sentinel lifetime n_steps (never freed).
    program[0] = [Operation('mem', 0, group[0], n_steps) for group in tree[0]]

    def keep_until(op, step):
        # An entry can be read at several later steps: only ever extend its
        # lifetime.  Overwriting it would free memory that is still needed.
        op.lifetime = max(op.lifetime, step)

    for i in range(1, n_steps):
        for group in tree[i]:
            prefix = group[:-1]
            if prefix in tree[i - 1]:
                # group = a group of the previous level + one extra index.
                j = tree[i - 1].index(prefix)
                k = tree[0].index(group[-1:])
                keep_until(program[i - 1][j], i)
                keep_until(program[0][k], i)
                program[i].append(Operation('mul', j, k, i))
            else:
                # Cannot be built from the previous level: reuse the same
                # group from the nearest lower level that contains it.
                for j in range(i - 1, -1, -1):
                    if group in tree[j]:
                        break
                else:
                    raise KeyError(f'{group} not found in:\n{tree}')
                k = tree[j].index(group)
                keep_until(program[j][k], i)
                program[i].append(Operation('mem', j, k, i))

    clean_list = [[] for _ in range(n_steps + 1)]
    for i, step in enumerate(program):
        for j, op in enumerate(step):
            clean_list[op.lifetime].append((i, j))
    return program, clean_list


def dep_tdoa(tij, nind, ntot):
    """Fill the dependent entries ``tij[nind:ntot]`` in place.

    Entry ``nind + p`` becomes ``tij[i0] - tij[j0]`` for the ``p``-th pair in
    the order ``j0 = 0, 1, ...`` (outer), ``i0 = j0 + 1, ..., nind - 1``
    (inner).  :func:`num_ind` returns ``p`` for a given pair.
    """
    i0 = 1
    j0 = 0
    for i in range(nind, ntot):
        tij[i] = tij[i0] - tij[j0]
        i0 += 1
        if i0 >= nind:
            j0 += 1
            i0 = j0 + 1


def num_ind(i, j, nind):
    """Return the position of pair ``(i, j)`` among the dependent entries.

    The order is the one produced by :func:`dep_tdoa`.  The caller
    (:func:`mul`) always passes the larger index first (``i > j``).
    """
    return j * (nind - 1) + i - 1 - (j * (j + 1)) // 2


def mul(mem1, mem2, cc, t_max, id1, id2, n_ind):
    """Extend every candidate of ``mem1`` with every candidate of ``mem2``.

    ``mem1`` and ``mem2`` are ``(values, tij)`` pairs: column ``c`` of
    ``tij`` is a candidate delay vector (one row per index) and ``values[c]``
    is its score.  ``mem1`` covers the indices ``id1`` (a sequence) and
    ``mem2`` the single index ``id2``.

    Combinations where the new delay differs from any existing one by more
    than ``t_max`` are dropped.  The remaining ones are scored by multiplying
    in ``cc`` at the dependent channel of every pair ``(i, id2)``, ``i`` in
    ``id1``.

    Returns:
        The combined ``(values, tij)`` pair.
    """
    idx1, idx2 = np.meshgrid(np.arange(len(mem1[0])), np.arange(len(mem2[0])))
    idx1, idx2 = idx1.ravel(), idx2.ravel()
    out_tij = np.concatenate((mem1[1][:, idx1], mem2[1][:, idx2]), axis=0)
    mask = (np.abs(out_tij[:-1] - out_tij[-1:]) <= t_max).all(0)
    out_tij = out_tij[:, mask]
    idx1, idx2 = idx1[mask], idx2[mask]
    out_val = mem1[0][idx1] * mem2[0][idx2]
    # Dependent channels store tij[larger] - tij[smaller]; flip the sign
    # where id2 is the larger index of the pair.
    tij_dep = out_tij[:-1] - out_tij[-1:]
    tij_dep *= np.array([1 if i > id2 else -1 for i in id1])[:, np.newaxis]
    ch_dep = np.array([num_ind(i, id2, n_ind) if i > id2 else num_ind(id2, i, n_ind)
                       for i in id1]) + n_ind
    out_val *= cc[ch_dep[:, np.newaxis], tij_dep].prod(0)
    return out_val, out_tij


def mask_val(mem, val):
    """Keep only the candidates of ``mem`` whose score is at least ``val``."""
    mask = mem[0] >= val
    return mem[0][mask], mem[1][:, mask]


def constrained_argmax(mem, tij_ind, t_max):
    """Return the best-scoring candidate column of ``mem`` compatible with ``tij_ind``.

    A column is compatible when every delay in it lies in
    ``[max(tij_ind) - t_max, min(tij_ind) + t_max]``, i.e. within ``t_max``
    of every value in ``tij_ind``.

    NOTE(review): if no column is compatible, ``np.argmax`` receives an empty
    array and raises ``ValueError``.
    """
    min_t, max_t = tij_ind.min(), tij_ind.max()
    mask = ((max_t - t_max <= mem[1]) & (mem[1] <= min_t + t_max)).all(0)
    return mem[1][:, mask][:, np.argmax(mem[0][mask])]


def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list):
    """Search for the delay (TDOA) vector with the largest product of ``cc``.

    Args:
        cc: 2-D array indexed ``cc[channel, lag]``.  Negative lags are read
            from the end of each row (``cc[ch, -k]``).  Channels
            ``0 .. n_ind - 1`` are the independent delays; channels
            ``n_ind .. n_tot - 1`` are the pairwise ones laid out as in
            :func:`dep_tdoa`.
        n_ind: number of independent delays.
        n_tot: total number of channels (independent + dependent).
        t_max: largest absolute delay, and largest pairwise delay difference,
            that is considered.
        tree: output of ``gen_tree(n_ind)``.
        program, clean_list: output of ``op_tree(tree)``.

    Each step grows the candidate delay vectors of every group by one index.
    It prunes candidates whose partial score is below ``val``, the score of
    the last full assignment.  It then greedily assembles a new full
    assignment from the surviving candidates.

    Returns:
        ``(val, tij)``: the score of the assignment from the last step and
        the full delay vector (length ``n_tot``).

    NOTE(review): pruning partial scores against a full score is only safe if
    multiplying in more ``cc`` factors never increases a score, i.e. ``cc``
    values lie in [0, 1] -- confirm for the caller's ``cc``.
    """
    lags = np.arange(-t_max, t_max + 1)
    memory = dict()
    val = 0
    tij = np.zeros(n_tot, int)
    for i, step in enumerate(program):
        # Build this step's memory entries (one more index per group).
        for j, op in enumerate(step):
            if op == 'mem':
                if i == 0:
                    # Integer-array indexing wraps negative lags to the end of
                    # the row.  The former `-t_max:` slice selected the whole
                    # row when t_max == 0.
                    memory[(0, j)] = (cc[op.right, lags], lags[np.newaxis])
                else:
                    memory[(i, j)] = mask_val(memory[(op.left, op.right)], val)
            else:  # op == 'mul'
                memory[(i, j)] = mul(mask_val(memory[(i - 1, op.left)], val),
                                     mask_val(memory[(0, op.right)], val),
                                     cc, t_max, tree[i - 1][op.left],
                                     tree[0][op.right][0], n_ind)
        # Greedily assemble a full assignment; its score is the new bound.
        tij[:] = 0
        for j in range(len(step)):
            tij[tree[i][j]] = constrained_argmax(memory[(i, j)], tij[:n_ind], t_max)
        dep_tdoa(tij, n_ind, n_tot)
        val = cc[np.arange(n_tot), tij].prod()
        # Free entries that no later step reads.
        for p in clean_list[i]:
            del memory[p]
    return val, tij