#!/usr/bin/env python3
from collections import deque
from copy import deepcopy
from functools import reduce
from itertools import count
import tkinter as TK

# ################################################################
# General utilities

class Gensym:
    """
    Callable that hands out a new name on every call, built from a fixed
    prefix followed by the next value of the `seq` iterator.

    Because the default `seq=count()` is evaluated once, when the class is
    defined, every instance created without an explicit `seq` draws from
    that same single counter (so their numbers never repeat across them).
    """
    def __init__(self, prefix="g", seq=count()):
        self.prefix = prefix
        self.seq = seq

    def __call__(self):
        number = next(self.seq)
        return "{}{}".format(self.prefix, number)


# Module-wide name generator: successive calls return "g0", "g1", "g2", ...
# as long as nothing else draws from the shared default counter in between.
gensym = Gensym(prefix="g")

def dict_filter(dictionary, keys):
    """
    Build a new dictionary restricted to `keys`, each mapped to its value in
    `dictionary`. A key absent from `dictionary` raises KeyError.
    """
    return {k: dictionary[k] for k in keys}

def integers(n):
    """
    List the integers 0, 1, ..., n-1 in increasing order
    (an empty list when n <= 0).
    """
    return list(range(0, n))

def flatten(l):
    """
    Concatenate the sublists of `l` into one new list, keeping their order.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

# ################################################################
# Automata networks themselves

class Node:
    """
    An automata network node.

    Constructor fields:
    name         -- Any hashable value, used for printing and identification
                    (must be unique inside an AN or all hell breaks loose)
    update_func  -- Update function; it must take one parameter: a dictionary
                    mapping names of dependencies to their values
    dependencies -- List of dependency names (other nodes' names). When
                    omitted, each node gets its own fresh empty list.
    init_state   -- The initial value of the node (may be any value)

    Other fields:
    state -- Current value of the node
    """
    def __init__(self, name, update_func=None, dependencies=None, init_state=0):
        self.name = name
        self.update_func = update_func
        # A `[]` default would be one list shared by every Node built without
        # explicit dependencies, so add_dependency() on one of them would
        # silently add the dependency to all of them.
        self.dependencies = [] if dependencies is None else dependencies
        self.init_state = init_state
        self.reset_state()  # Sets self.state

    def reset_state(self):
        """Restore the node's state to its initial value."""
        self.state = self.init_state

    def add_dependency(self, new_dependency):
        """Append a dependency name to this node's dependency list."""
        self.dependencies.append(new_dependency)

    def update(self, neigh_vals):
        """
        Set the state to update_func(neigh_vals).

        neigh_vals -- dictionary mapping dependency names to their values.
        Raises TypeError if the node has no update function.
        """
        if self.update_func is None:
            raise TypeError("node {!r} has no update function".format(self.name))
        self.state = self.update_func(neigh_vals)

    def __repr__(self):
        return "{{node {}: {}}}".format(self.name, self.state)

class AN:
    """
    An automata network.

    Constructor fields:
    nodes       -- An iterable (set, list…) of nodes. Defaults to no nodes.
    update_mode -- A list of sets (or set-like objects) of node names, viewed
                   as a periodic update mode. Defaults to an empty list.
    history_len -- How many states to remember for the rollback feature.
                   This field may be altered later, and the history will be
                   resized accordingly, keeping the most recent states.

    Other fields:
    up      -- Stands for Update Pointer.
               It is an index in update_mode that “points” to the next update
               to perform. Note that foo.up can be publicly read and written.
               A modulo len(update_mode) is implicitly applied, so it is safe
               to do, e.g., foo.up +=1 on an update. With an empty
               update_mode it stays 0.
    history -- A circular buffer (a deque with maxlen=history_len) of the
               previous values of _nodes (see below), oldest first. Each
               step stores the old dictionary and continues on a deep copy.
               It can be publicly accessed. Please do not alter the history
               by yourself (Stalin did nothing right).
    _nodes  -- A dictionary mapping node names to nodes.
               When the field `nodes` is read, it is actually generated
               on-the-fly as the values of _nodes. Conversely when
               `nodes` is written, a dictionary is implicitly generated and
               written to `_nodes`.
    """

    def __init__(self, nodes=None, update_mode=None, history_len=100):
        # None defaults avoid sharing one mutable object between instances.
        self.nodes = () if nodes is None else nodes
        self.update_mode = [] if update_mode is None else update_mode
        self.up = 0
        self.history = deque(maxlen=history_len)

    # ################################################################

    @property
    def nodes(self):
        """Return a view of all nodes (the values of _nodes)."""
        return self._nodes.values()

    @nodes.setter
    def nodes(self, new_nodes):
        self._nodes = {}
        for node in new_nodes:
            self.add_node(node)

    @property
    def state(self):
        """Return a dictionary mapping node names to node states."""
        return {name: node.state for name, node in self._nodes.items()}

    def add_node(self, node):
        """Adds a new node to the AN (replacing any node with the same name)."""
        self._nodes[node.name] = node

    def del_node(self, node_name):
        """
        Remove a node from the AN (by node name). Returns the deleted node.
        If no node with such name is found, raises KeyError.
        """
        return self._nodes.pop(node_name)

    @property
    def up(self):
        return self._up

    @up.setter
    def up(self, new_up):
        period = len(self.update_mode)
        # An empty update mode has no block to point to; stay at 0 rather
        # than failing with ZeroDivisionError (which made AN() unusable).
        self._up = new_up % period if period else 0

    @property
    def history_len(self):
        """Return the maximum number of remembered states."""
        return self.history.maxlen

    @history_len.setter
    def history_len(self, new_history_len):
        # Building a bounded deque from the old one keeps only its rightmost
        # (most recent) entries, in their original order, so unupdate() still
        # pops the latest state first.
        self.history = deque(self.history, maxlen=new_history_len)

    def __repr__(self):
        result = "{AN \n"
        for node in self.nodes:
            result += "    " + str(node) + "\n"
        result += "}"
        return result

    # ################################################################

    def reset_state(self):
        """
        Set each node to its initial value, and the up to 0.
        The previous node values are pushed onto the history.
        """
        new_nodes = deepcopy(self._nodes)
        # Iterate over the nodes themselves; iterating the dict yields names.
        for node in new_nodes.values():
            node.reset_state()
        self.history.append(self._nodes)
        self._nodes = new_nodes
        # NOTE: the history stores node values only, not `up`, so unupdate()
        # after a reset restores the nodes but decrements up from 0 instead of
        # restoring its pre-reset value.
        self.up = 0

    def update(self):
        """Update the AN once, applying the block of update_mode at `up`."""
        new_nodes = deepcopy(self._nodes)
        # All nodes of the block read the states from before this step, so
        # one snapshot suffices (the old dict is not modified in the loop).
        current = self.state
        for name in self.update_mode[self.up]:
            local_view = dict_filter(current, self._nodes[name].dependencies)
            new_nodes[name].update(local_view)
        self.history.append(self._nodes)
        self._nodes = new_nodes
        self.up += 1

    def unupdate(self):
        """
        Roll back one step of the AN, within limits of history_len.
        Raises IndexError when there is no history left.
        """
        if not self.history:
            raise IndexError("no history left to roll back")
        self._nodes = self.history.pop()
        self.up -= 1

    def run(self, steps):
        """Update the AN `steps` times. If `steps` is negative, roll back."""
        if steps < 0:
            for _ in range(-steps):
                self.unupdate()
        else:
            for _ in range(steps):
                self.update()

# ################################################################
# Update modes

def parallel_mode(node_names):
    """
    Build the parallel update mode: one single block holding `node_names`
    (the very object passed in), so every node is updated at each step.
    """
    single_block = node_names
    return [single_block]

def local_clocks(period, deltas):
    """
    Return a local clocks update mode.

    period -- The global period of the mode, i.e. its number of blocks.
    deltas -- A dictionary mapping node names to phases; each phase must be
              an integer in range(period).

    Returns a list of `period` lists; block t holds the names of the nodes
    whose phase is t, in the iteration order of `deltas`.
    Raises ValueError if a phase is outside range(period).
    """
    # One independent list per block (`[] * period` is just an empty list).
    result = [[] for _ in range(period)]
    for name, phase in deltas.items():
        # Reject out-of-range phases explicitly: a negative one would
        # otherwise silently wrap around through Python's negative indexing.
        if not 0 <= phase < period:
            raise ValueError(
                "phase {!r} of node {!r} is not in range({})".format(
                    phase, name, period))
        result[phase].append(name)
    return result

def is_block_seq(mode):
    """
    Return whether an update mode is block sequential
    (no node is updated twice).

    mode -- an iterable of blocks, each an iterable of hashable node names.
    """
    already_seen = set()
    # Walk the blocks directly so we can stop at the first repeated name
    # without building a flattened copy of the mode.
    for block in mode:
        for node_name in block:
            if node_name in already_seen:
                return False
            already_seen.add(node_name)
    return True

def block_seq(mode):
    """
    Checks that `mode` is block-sequential. If so, return it.
    If not, raise a ValueError exception. Useful for functional programming.
    """
    if not is_block_seq(mode):
        # Say what was rejected instead of raising an empty ValueError().
        raise ValueError(
            "update mode is not block-sequential "
            "(some node is updated more than once): {!r}".format(mode))
    return mode

def periodic_mode(mode):
    """
    Hand back `mode` itself, untouched: an identity function under a
    descriptive name, handy when composing mode-building functions.
    """
    unchanged = mode
    return unchanged

# ################################################################
# Node construction

def node_fixpoint(name, init_val=0):
    """
    Build a node whose only dependency is itself; its update function returns
    the first value of the dictionary it receives, which with dependencies
    [name] is the node's own current value.
    """
    def identity(args):
        # First value of `args`, or None when `args` is empty.
        return next(iter(args.values()), None)

    return Node(name, identity, [name], init_val)

def node_constant(name, value=0, init_val=None):
    """
    Return a node that depends on nobody and updates to the given value.

    name     -- node name
    value    -- value produced by every update
    init_val -- initial state; when omitted (None) it defaults to `value`
    """
    # Test for None only: `if not init_val` would replace an explicit falsy
    # initial value such as 0 or False with `value`.
    if init_val is None:
        init_val = value

    def constant(args):
        return value

    return Node(name, constant, [], init_val)

# ################################################################
# Cycle ANs

def cycle_update_func(sign=True):
    """
    Build the update function of a node in a cycle.

    The returned function takes a dictionary of neighbour values (a single
    neighbour is assumed) and returns its first value as-is when `sign` is
    truthy, or that value negated otherwise. Both return None for an empty
    dictionary.
    """
    def positive(neighs):
        for value in neighs.values():
            return value

    def negative(neighs):
        for value in neighs.values():
            return not value

    return positive if sign else negative

def cycle(signs, init_vals=None, update_mode=None):
    """
    Return a cyclic AN.
    Node names are 0,…,n-1, with n=len(signs). Node i depends on node i-1 mod n.
    signs       -- signs[i] is True if i depends positively on i-1, False if
                   negatively
    init_vals   -- iterable of initial values (booleans); defaults to all True.
                   Only its first n values are used.
    update_mode -- a periodic update mode; defaults to parallel

    Raises ValueError if init_vals provides fewer than n values.
    """
    from itertools import islice

    num_nodes = len(signs)
    node_names = integers(num_nodes)

    if not init_vals:
        init_vals = [True] * num_nodes
    else:
        # Take at most n values (works for lists, generators and infinite
        # iterators alike) so that a short init_vals is detected: zip() would
        # otherwise silently drop trailing nodes that node 0 and the default
        # update mode still refer to.
        init_vals = list(islice(init_vals, num_nodes))
        if len(init_vals) < num_nodes:
            raise ValueError("expected {} initial values, got {}".format(
                num_nodes, len(init_vals)))
    if not update_mode:
        update_mode = parallel_mode(node_names)

    nodes = []
    for i, (sign, init_val) in enumerate(zip(signs, init_vals)):
        # (i - 1) % num_nodes makes node 0 depend on node n-1.
        node = Node(i, cycle_update_func(sign), [(i - 1) % num_nodes], init_val)
        nodes.append(node)
    return AN(nodes, update_mode)

# ################################################################
# And-not networks

def land(b1, b2):
    """
    Function form of the `and` operator (usable with `reduce`): yields b1
    when it is falsy, and b2 otherwise.
    """
    return b2 if b1 else b1

def and_not_update_func(dep_list):
    """
    Return an update function for a node in an and_not network.
    If dep_list is [1, -2, 3, -4], then the node has positive neighbours 1, 3
    and negative neighbours 2, 4.

    The returned function takes a dictionary mapping neighbour names to
    values and returns, as a bool, the conjunction of the positive
    neighbours' values and the negations of the negative neighbours' values.
    With no neighbours it returns True (the empty conjunction) instead of
    failing.
    """
    def node_update_func(args):
        # all() handles an empty dep_list, where reduce() without an initial
        # value would raise TypeError.
        return all(not args[-n] if n < 0 else args[n] for n in dep_list)
    return node_update_func

def and_not(dep_lists, init_vals, update_mode):
    """
    Build an and-not network whose nodes are named 1, …, N,
    where N = len(dep_lists) - 1 = len(init_vals) - 1.

    dep_lists   -- entry 0 is a placeholder and is skipped; entry i lists the
                   neighbours of node i as integers, a negative integer
                   standing for a negated neighbour (e.g. 4 in dep_lists[1]
                   means 1 depends positively on 4, -4 negatively).
    init_vals   -- entry 0 is a placeholder and is skipped; entry i is the
                   initial value of node i.
    update_mode -- periodic update mode, i.e. a list of sets of integers

    The two lists are walked in lockstep, so nodes are only created up to the
    shorter of them.
    """
    nodes = []
    pairs = zip(dep_lists[1:], init_vals[1:])
    for i, (dep_list, init_val) in enumerate(pairs, start=1):
        # Dependency names are the absolute values; the signs are handled by
        # the update function.
        deps = list(map(abs, dep_list))
        nodes.append(Node(i, and_not_update_func(dep_list), deps, init_val))
    return AN(nodes, update_mode)

# ################################################################

# Example and-not network with 14 nodes (names 1..14), in the format taken
# by and_not(): entry i lists the neighbours of node i, a negative number
# meaning a negated neighbour.
example_deps = [
    None, #ignored (node names start at 1)
    [3, 5],
    [4, 6],
    [1, 7],
    [2, 7],
    [1, -8],
    [2, -8],
    [3, 4, 9],
    [-5, -6, -9, -10],
    [-8, 13, 14],
    [11, 12],
    [10],
    [10],
    [9, 14],
    [9, 13],
]

# Initial values for the example, indexed like example_deps.
# TODO(review): only the ignored placeholder is present; and_not() walks
# dep_lists[1:] and init_vals[1:] in lockstep, so with this list it would
# build a network with no nodes at all.
example_init_vals = [
    None, #ignored
]

# Periodic update mode for the example (a list of blocks of node names).
# TODO(review): currently a single empty block, so AN.update() would update
# no node at all.
example_update = [
    [],
]