import pickle
from pygraphviz import *

import src.globs as globs

class Command:
    """One node of the command graph.

    Attributes:
        name: the command this node stands for.
        depth: level of the node in the graph (``Tree`` puts its root at 0).
        idn: unique id drawn from the class-wide ``idn_pool`` counter.
        succs: successor nodes, in the order they were linked.
        errnode: True when the node marks an error (set by training code).
        errmsg: message attached to an error node.
    """

    # Next id to hand out; shared by every instance of the class.
    idn_pool = 0

    def __init__(self, name, depth):
        claimed = Command.idn_pool
        Command.idn_pool = claimed + 1
        self.idn = claimed
        self.name = name
        self.depth = depth
        self.succs = []
        self.errnode = False
        self.errmsg = ""

    def __str__(self):
        """Return the command name unchanged."""
        return self.name

    def __repr__(self):
        """Return the command name converted to ``str``."""
        return str(self.name)

    def add_succ(self, comm):
        """Append ``comm`` to the successors (no duplicate check here)."""
        self.succs += [comm]

class Tree:
    """Directed graph of ``Command`` nodes merged from several command sequences.

    Attributes:
        start: root node named ``"_start"`` at depth 0.
        seqpos: for each training sequence, the node that sequence reached last.
        known_commands: every node of the graph, in the order it was added.
    """

    def __init__(self, nseq):
        root = Command("_start", 0)
        self.start = root
        # Every sequence cursor begins at the shared root.
        self.seqpos = [root for _ in range(nseq)]
        self.known_commands = [root]

    def max_depth(self):
        """Return the largest node depth, never less than 0."""
        return max([0, *(node.depth for node in self.known_commands)])

    def max_idn(self):
        """Return the largest node id, never less than 0."""
        return max([0, *(node.idn for node in self.known_commands)])

    def print_state(self, state):
        """Print ``state`` and, depth-first, each successor not printed yet.

        Relies on ``self.to_print`` (initialised by ``_print``) to remember
        which nodes were already visited.
        """
        visited = self.to_print
        visited.append(state)
        print(state, state.idn, state.depth, state.succs)
        for nxt in state.succs:
            if nxt in visited:
                continue
            self.print_state(nxt)

    def _print(self):
        """Dump the whole graph to stdout, starting from the root."""
        self.to_print = []
        self.print_state(self.start)

    def nodes_in_depth(self, depth):
        """Return the nodes whose depth equals ``depth``, in insertion order."""
        return [node for node in self.known_commands if node.depth == depth]

    def insert(self, command, seq):
        """Advance sequence ``seq`` by one step with the command ``command``.

        An existing node named ``command`` is reused when one exists on the
        current node's own depth or on any deeper depth (shallowest depth
        first, insertion order within a depth).  Otherwise a new node is
        created one level below the current node.  The current node gets an
        edge to the chosen node (added only once) and ``seqpos[seq]`` moves
        onto it.
        """
        pred = self.seqpos[seq]
        base = pred.depth
        deeper = [base + step for step in range(1, self.max_depth() - base + 1)]
        for level in [base, *deeper]:
            reuse = next(
                (node for node in self.nodes_in_depth(level) if node.name == command),
                None,
            )
            if reuse is None:
                continue
            if reuse not in pred.succs:
                pred.add_succ(reuse)
            self.seqpos[seq] = reuse
            return

        # No node with this name anywhere at or below pred's level: create one.
        fresh = Command(command, base + 1)
        self.known_commands.append(fresh)
        self.seqpos[seq] = fresh
        if fresh not in pred.succs:
            pred.add_succ(fresh)

    def draw(self, args):
        """Render the graph with Graphviz ``dot`` to a PDF.

        The output goes to ``globs.fdir + "/output/" + args.trid + ".method3.pdf"``.
        Nodes of equal depth are placed on one rank.  Error nodes are filled
        red and show their message under the name.  Edges into error nodes
        are drawn red.
        """
        graph = AGraph(directed=True)
        graph.node_attr["style"] = "filled"
        graph.node_attr["forcelabels"] = True
        graph.graph_attr["ranksep"] = 1.5

        # Bucket nodes by depth so each depth can be pinned to one rank.
        buckets = [[] for _ in range(self.max_depth() + 1)]
        for cmd in self.known_commands:
            buckets[cmd.depth].append(cmd)

        for bucket in buckets:
            for cmd in bucket:
                graph.add_node(n=cmd.idn, label=cmd.name, shape="circle")
                if not cmd.errnode:
                    continue
                gnode = graph.get_node(cmd.idn)
                gnode.attr["fillcolor"] = "red"
                # <...> wrapping is meant as a Graphviz HTML-like label with a
                # line break.  NOTE(review): name/errmsg are not HTML-escaped.
                gnode.attr["label"] = "<" + cmd.name + "<BR/>" + cmd.errmsg + ">"
            graph.add_subgraph((c.idn for c in bucket), rank="same")

        for cmd in self.known_commands:
            for nxt in cmd.succs:
                graph.add_edge(cmd.idn, nxt.idn)
                if nxt.errnode:
                    graph.get_edge(cmd.idn, nxt.idn).attr["color"] = "red"

        graph.layout("dot")
        graph.draw(globs.fdir + "/output/" + args.trid + ".method3.pdf")

def train_correct(commands, args):
    """Build a command tree from correct sequences, draw it and pickle it.

    Commands are inserted position by position: the first command of every
    sequence, then the second command of every sequence, and so on.  Equal
    commands are merged across sequences by ``Tree.insert``.

    :param commands: list of command sequences; sequence ``j`` is tracked by
        the tree's cursor ``j``.  An empty list yields a tree holding only
        the root.
    :param args: object whose ``trid`` string names the output files.
    :return: the built ``Tree``, also pickled to
        ``globs.fdir + "/output/" + args.trid + ".method3"``.
    """
    tree = Tree(len(commands))
    # default=0: an empty `commands` list must not make max() raise ValueError.
    longest = max((len(seq) for seq in commands), default=0)
    for pos in range(longest):
        for seq_idx, seq in enumerate(commands):
            if pos < len(seq):
                tree.insert(seq[pos], seq_idx)
    tree.draw(args)
    with open(globs.fdir + "/output/" + args.trid + ".method3", "wb") as fp:
        pickle.dump(tree, fp)
    return tree

def train_incorrect(commands, args):
    """Record one incorrect command sequence as an error branch of the saved tree.

    The function loads the tree pickled for ``args.trid`` and follows
    ``commands`` along existing successor edges.

    * At the first command with no matching successor, it attaches a new node
      flagged as an error and carrying ``args.errmsg``.
    * If the walk instead lands on a node already flagged as an error, it
      appends ``" OR " + args.errmsg`` to that node's message.

    Either way the walk stops there.  The tree is then redrawn and pickled
    back to the same file.
    """
    model_path = globs.fdir + "/output/" + args.trid + ".method3"
    # NOTE: pickle.load runs arbitrary code; only load files this module wrote.
    with open(model_path, "rb") as fp:
        tree = pickle.load(fp)
    # Continue numbering after the loaded nodes so new ids do not collide.
    Command.idn_pool = tree.max_idn() + 1

    node = tree.start
    for name in commands:
        nxt = next((s for s in node.succs if name == s.name), None)
        if nxt is None:
            # NOTE(review): the error node gets the *same* depth as its parent
            # (Tree.insert uses parent depth + 1 for new nodes); confirm intent.
            bad = Command(name, node.depth)
            bad.errnode = True
            bad.errmsg = args.errmsg
            node.add_succ(bad)
            tree.known_commands.append(bad)
            break
        node = nxt
        if node.errnode:
            node.errmsg += " OR " + args.errmsg
            break

    tree.draw(args)
    with open(model_path, "wb") as fp:
        pickle.dump(tree, fp)

def check(commands, args):
    """Check ``commands`` against the tree pickled for ``args.trid``.

    The sequence is walked from the root along successor edges.  A command
    that is not a successor of the current node is handled like this:

    * If its name belongs to some node of the tree and the previous command
      was matched (or it is the first command), RuntimeError is raised.
    * Otherwise, if it is the last command, it is accepted (trailing
      insertion/replacement).
    * Otherwise it is skipped.  The current node moves to its first successor
      that has the *next* command among its own successors, or stays put if
      there is none.

    :return: True when no error was raised.
    :raises RuntimeError: ``"Method 3: Error found - <msg>"`` when a step
        lands on a node flagged as an error; ``"Method 3: Unknown command"``
        under the first rule above.
    """
    # NOTE: pickle.load runs arbitrary code; only load files this module wrote.
    with open(globs.fdir + "/output/" + args.trid + ".method3", "rb") as fp:
        tree = pickle.load(fp)
    activ = tree.start
    # Exact names of all nodes.  (A substring search in str(known_commands)
    # wrongly treated fragments such as "l" of "ls" or ", " as known.)
    known_names = {node.name for node in tree.known_commands}

    last_unknown = False
    for i, c in enumerate(commands):
        match = next((s for s in activ.succs if c == s.name), None)
        if match is not None:
            if match.errnode:
                raise RuntimeError("Method 3: Error found - " + match.errmsg)
            activ = match
            last_unknown = False
            continue

        if c in known_names and not last_unknown:
            raise RuntimeError("Method 3: Unknown command")

        # The last command may be inserted or replaced.
        if (i + 1) == len(commands):
            return True

        # Skip one replaced command.  When skipping an unknown inserted
        # sequence, the active node stays unchanged until we reach the first
        # known command.
        following = commands[i + 1]
        for s1 in activ.succs:
            if any(s2.name == following for s2 in s1.succs):
                activ = s1
                break

        last_unknown = True
    return True
