import numpy as np
from math import ceil


class Operation:
    """A single instruction of the program built by ``op_tree``.

    Attributes:
        op: the opcode, ``'mem'`` or ``'mul'``.
        left, right: the two operands; their meaning depends on ``op``.
        lifetime: step index under which the result is scheduled for release.

    Comparing an Operation with a string compares the opcode only, so
    ``operation == 'mem'`` tests the kind of instruction.
    """
    __slots__ = ('op', 'left', 'right', 'lifetime')

    def __init__(self, op, left, right, lifetime):
        self.op, self.left, self.right, self.lifetime = op, left, right, lifetime

    def __repr__(self):
        return f'(op: {self.op}, l: {self.left}, r: {self.right}, lt: {self.lifetime})'

    def __eq__(self, other):
        if isinstance(other, Operation):
            # Full structural equality between two instructions.
            return (self.op == other.op and self.left == other.left
                    and self.right == other.right and self.lifetime == other.lifetime)
        if isinstance(other, str):
            # Opcode match only.
            return self.op == other
        return False

    # Defining __eq__ already makes instances unhashable; spelled out for clarity.
    __hash__ = None


def gen_tree(size):
    """Build the grouping levels for ``size`` independent indices.

    Returns a list of ``size`` levels; each level is a list of groups (lists of
    indices). The last level is the single group ``[0, ..., size - 1]``. Each
    lower level ``i`` is derived from level ``i + 1`` so that no group holds
    more than ``i + 1`` indices:

    * groups that are too long lose their last index;
    * the removed indices top up short groups, in order;
    * whatever remains is split into new groups of ``i + 1``.

    Every level is therefore a partition of ``range(size)``.

    >>> gen_tree(3)
    [[[0], [2], [1]], [[0, 1], [2]], [[0, 1, 2]]]

    Raises IndexError when ``size`` < 1.
    """
    levels = [[] for _ in range(size)]
    levels[-1].append(list(range(size)))
    for level in reversed(range(size - 1)):
        width = level + 1
        current = levels[level]
        spare = []
        # Trim every group of the level above to at most `width` indices.
        for parent in levels[level + 1]:
            if len(parent) > width:
                current.append(parent[:-1])
                spare.append(parent[-1])
            else:
                current.append(parent.copy())
        # Top up short groups, in order, with the trimmed-off indices.
        for group in current:
            if not spare:
                break
            missing = width - len(group)
            if missing > 0:
                group.extend(spare[:missing])
                spare = spare[missing:]
        # Anything still spare forms new groups of `width` indices.
        current.extend(spare[start:start + width]
                       for start in range(0, len(spare), width))
    return levels


def op_tree(tree):
    """Compile the grouping levels of ``tree`` into a step-by-step program.

    ``program[i]`` holds one Operation per group of ``tree[i]``, in the same
    order:

    * step 0, ``'mem'``: ``left`` is 0 and ``right`` is the group's only element;
    * ``'mul'``: the group is ``tree[i-1][left]`` plus one element, whose
      singleton group is ``tree[0][right]``;
    * step i > 0, ``'mem'``: the group already appears unchanged as
      ``tree[left][right]`` at an earlier step (the latest such step is used).

    Each Operation's ``lifetime`` is the last step that reads its result. It is
    never earlier than the Operation's own step. Step-0 operations start at
    ``len(tree)``, one past the last step index. Lifetimes are only ever
    extended, never shortened. A step-0 singleton used as a ``'mul'`` prefix at
    step 1 may still be read later as a ``'mul'`` right operand or by a
    ``'mem'`` reuse.

    ``clean_list[s]`` lists the ``(step, index)`` of every Operation whose
    lifetime is ``s``.

    Returns:
        ``(program, clean_list)``. For an empty tree: ``([], [[]])``.

    Raises:
        KeyError: a group neither extends a group of the previous step nor
            appears at any earlier step.
        ValueError: the last element of a ``'mul'`` group has no singleton
            group in ``tree[0]``.
    """
    n_steps = len(tree)
    program = [[] for _ in range(n_steps)]
    clean_list = [[] for _ in range(n_steps + 1)]
    if not tree:
        return program, clean_list

    def keep_alive(step, idx, until):
        # Extend (never shorten) the lifetime of program[step][idx] so its
        # result is still available when step `until` reads it.
        op = program[step][idx]
        op.lifetime = max(op.lifetime, until)

    program[0] = [Operation('mem', 0, group[0], n_steps) for group in tree[0]]
    for i in range(1, n_steps):
        previous = tree[i - 1]
        for group in tree[i]:
            prefix = group[:-1]
            if prefix in previous:
                j = previous.index(prefix)
                k = tree[0].index(group[-1:])
                # Both operands of the product are read during step i.
                keep_alive(i - 1, j, i)
                keep_alive(0, k, i)
                program[i].append(Operation('mul', j, k, i))
            else:
                # Latest earlier step at which this exact group already exists.
                for j in range(i - 1, -1, -1):
                    if group in tree[j]:
                        break
                else:
                    raise KeyError(f'{group} not found in:\n{tree}')
                k = tree[j].index(group)
                # The reused result is read during step i.
                keep_alive(j, k, i)
                program[i].append(Operation('mem', j, k, i))
    for i, step in enumerate(program):
        for j, op in enumerate(step):
            clean_list[op.lifetime].append((i, j))
    return program, clean_list


def dep_tdoa(tij, nind, ntot):
    """Fill ``tij[nind:ntot]`` in place with differences of the first ``nind`` entries.

    Walking the pairs ``(a, b)`` with ``b < a`` in the order (1, 0), (2, 0), ...,
    (nind - 1, 0), (2, 1), (3, 1), ..., successive slots starting at
    ``tij[nind]`` receive ``tij[a] - tij[b]``.

    Returns None; ``tij`` is modified in place.
    """
    col, row = 1, 0
    for slot in range(nind, ntot):
        tij[slot] = tij[col] - tij[row]
        col += 1
        if col < nind:
            continue
        # Finished pairs with smaller index `row`; move on to the next one.
        row += 1
        col = row + 1


def num_ind(i, j, nind):
    """Return the 0-based position of pair ``(i, j)``, ``j < i < nind``, among the dependent TDOAs.

    Pairs are ordered by ``j`` first, then ``i``: (1, 0) -> 0, (2, 0) -> 1, ...,
    (nind - 1, 0) -> nind - 2, (2, 1) -> nind - 1, ...

    Callers add ``nind`` to turn the position into an absolute channel index.
    """
    row_start = j * (nind - 1) - (j * (j + 1)) // 2
    return row_start + i - 1


def mul(mem1, mem2, cc, t_max, id1, id2, n_ind):
    """Join two candidate sets into candidates for the concatenated group.

    ``mem1`` and ``mem2`` are ``(values, tdoas)`` pairs. Column ``c`` of
    ``tdoas`` is a candidate TDOA vector and ``values[c]`` is its score.
    ``mem1`` covers the independent TDOAs listed in ``id1``. ``mem2`` is assumed
    to cover only the single independent TDOA ``id2`` (one row).

    Every pairing of a ``mem1`` candidate with a ``mem2`` candidate is formed.
    A pairing is discarded if, for some ``k`` in ``id1``, ``|t_k - t_id2|``
    exceeds ``t_max``. A kept pairing scores ``values1 * values2`` times the
    ``cc`` value of each implied dependent TDOA. That dependent channel is
    ``n_ind + num_ind(max(k, id2), min(k, id2), n_ind)`` at lag
    ``t_max(k, id2) - t_min(k, id2)``.

    Returns:
        ``(values, tdoas)`` for the kept pairings. ``tdoas`` stacks the ``id1``
        rows above the ``id2`` row.
    """
    vals1, tdoas1 = mem1[0], mem1[1]
    vals2, tdoas2 = mem2[0], mem2[1]

    # Flattened index pairs covering every (mem1 candidate, mem2 candidate).
    grid1, grid2 = np.meshgrid(np.arange(len(vals1)), np.arange(len(vals2)))
    pick1, pick2 = grid1.ravel(), grid2.ravel()

    joined = np.concatenate((tdoas1[:, pick1], tdoas2[:, pick2]), axis=0)
    # t_k - t_id2 for every k in id1: the dependent TDOAs this join introduces.
    diffs = joined[:-1] - joined[-1:]
    feasible = (np.abs(diffs) <= t_max).all(0)

    joined = joined[:, feasible]
    pick1, pick2 = pick1[feasible], pick2[feasible]
    scores = vals1[pick1] * vals2[pick2]

    # Dependent TDOAs are stored as t_hi - t_lo (larger index first), so the
    # difference is negated whenever the id1 entry is the smaller index.
    pairs = [(k, id2, 1) if k > id2 else (id2, k, -1) for k in id1]
    signs = np.array([sign for _, _, sign in pairs])
    channels = np.array([num_ind(hi, lo, n_ind) for hi, lo, _ in pairs]) + n_ind

    lags = diffs[:, feasible]
    lags *= signs[:, np.newaxis]
    # Negative lags index cc from the end of the row (wrap-around).
    scores *= cc[channels[:, np.newaxis], lags].prod(0)
    return scores, joined


def mask_val(mem, val):
    """Keep only the candidates of ``mem`` whose score is at least ``val``.

    ``mem`` is a ``(values, tdoas)`` pair with one ``tdoas`` column per value.
    Returns a new ``(values, tdoas)`` pair restricted to the surviving columns.
    """
    scores, tdoas = mem[0], mem[1]
    keep = scores >= val
    return scores[keep], tdoas[:, keep]


def constrained_argmax(mem, tij_ind, t_max):
    """Return the best-scoring candidate TDOA vector that fits with ``tij_ind``.

    ``mem`` is a ``(values, tdoas)`` pair with one ``tdoas`` column per
    candidate. A candidate is eligible only if every component lies in
    ``[max(tij_ind) - t_max, min(tij_ind) + t_max]``, i.e. within ``t_max`` of
    every entry of ``tij_ind``. Returns the ``tdoas`` column of the eligible
    candidate with the highest value.

    Raises ValueError (from ``np.argmax`` on an empty array) when no candidate
    is eligible.
    """
    scores, tdoas = mem[0], mem[1]
    lower = tij_ind.max() - t_max
    upper = tij_ind.min() + t_max
    inside = np.all((tdoas >= lower) & (tdoas <= upper), axis=0)
    best = np.argmax(scores[inside])
    return tdoas[:, inside][:, best]


def smart_gsrp(cc, n_ind, n_tot, t_max, tree, program, clean_list):
    """Search for the TDOA vector maximising the product of correlation values.

    Runs ``program`` one step at a time and keeps candidate sets in ``memory``,
    keyed by ``(step, op index)``. After every step a full TDOA vector is
    assembled from the best constrained candidate of each group and scored.
    That score ``val`` becomes the pruning threshold for the next step:
    candidates scoring below it are dropped.

    NOTE(review): this pruning is only safe if partial products never increase
    as factors are multiplied in, i.e. presumably ``0 <= cc <= 1``. Confirm.

    Args:
        cc: 2-D array read as ``cc[channel, lag]``. Negative lags are read
            through Python wrap-around indexing. Rows ``0..n_ind-1`` are the
            independent channels. Rows ``n_ind..n_tot-1`` are the dependent
            ones, in the order ``dep_tdoa`` fills them.
        n_ind: number of independent TDOAs.
        n_tot: total number of TDOAs (independent + dependent).
        t_max: largest absolute lag searched.
        tree: grouping levels. ``tree[i][j]`` lists the independent TDOA
            indices covered by operation ``j`` of step ``i``.
        program, clean_list: expected to be the pair returned by
            ``op_tree(tree)``. The memory keys in ``clean_list[i]`` are deleted
            after step ``i``.

    Returns:
        ``(val, tij)``. ``val`` is ``prod(cc[k, tij[k]])`` over all ``n_tot``
        channels. ``tij`` is the integer TDOA vector (length ``n_tot``)
        assembled at the last step.

    NOTE(review): ``constrained_argmax`` raises if no candidate of some group
    satisfies the constraint.
    """
    memory = {}  # (step, op index) -> (values, tdoas) candidate set
    val = 0  # score of the last assembled full TDOA vector; pruning threshold
    tij = np.zeros(n_tot, int)
    # Every lag searched; negative lags read through wrap-around indexing of cc.
    # Fancy indexing (instead of cc[ch, -t_max:]) stays correct for t_max == 0,
    # where `-0:` would select the whole row.
    lags = np.arange(-t_max, t_max + 1)
    for i, step in enumerate(program):
        # Build this step's candidate sets (one dimension more than before).
        for j, op in enumerate(step):
            if op == 'mem':
                if i == 0:
                    # Seed: one candidate per lag of independent channel op.right.
                    memory[(0, j)] = (cc[op.right, lags], lags[np.newaxis])
                else:
                    # Group carried over unchanged: reuse its pruned candidates.
                    memory[(i, j)] = mask_val(memory[(op.left, op.right)], val)
            else:  # op == 'mul'
                # Extend group op.left of the previous step by one independent TDOA.
                memory[(i, j)] = mul(mask_val(memory[(i - 1, op.left)], val),
                                     mask_val(memory[(0, op.right)], val),
                                     cc, t_max, tree[i - 1][op.left], tree[0][op.right][0], n_ind)
        # Assemble a candidate maximum: the best candidate per group, kept within
        # t_max of the TDOAs already chosen (entries not yet chosen are 0).
        tij[:] = 0
        for j in range(len(step)):
            tij[tree[i][j]] = constrained_argmax(memory[(i, j)], tij[:n_ind], t_max)
        # Complete the dependent TDOAs and score the full vector.
        dep_tdoa(tij, n_ind, n_tot)
        val = cc[np.arange(n_tot), tij].prod()

        # Release candidate sets scheduled to expire after this step.
        for key in clean_list[i]:
            del memory[key]

    return val, tij