import copy
import html
import pickle

from pygraphviz import *

import src.globs as globs

class Command:
    """A single node of the command network.

    Every instance receives a unique integer ``idn`` from the class-wide
    counter ``idn_pool``.  Two commands compare equal when their ``idn``
    values are equal (the name is not compared); comparing with anything
    that is not a ``Command`` is simply unequal.
    """

    # Next identifier to hand out.  ``train_incorrect`` resets it after
    # loading a pickled network so that new nodes do not reuse old ids.
    idn_pool = 0

    def __init__(self, name):
        self.succs = []
        self.name = name
        self.idn = Command.idn_pool
        # True if the command can be called repeatedly in a row
        self.cycle = False
        # Marks a node added from an incorrect trace, with its message
        self.errnode = False
        self.errmsg = ""
        Command.idn_pool += 1

    def __eq__(self, other):
        # Returning NotImplemented (instead of touching ``other.idn``) lets
        # comparisons with None/strings fall back to "not equal" rather than
        # raising AttributeError.
        if not isinstance(other, Command):
            return NotImplemented
        return self.idn == other.idn

    def __hash__(self):
        # Must agree with __eq__; defining __eq__ alone makes the class
        # unhashable in Python 3.
        return hash(self.idn)

    def __str__(self):
        # Repeatable commands are marked with a trailing '*'.
        marker = "* " if self.cycle else " "
        return self.name + marker + str(self.idn)

    def __repr__(self):
        return self.__str__()

    def add_succ(self, comm):
        """Append *comm* to this node's successor list."""
        self.succs.append(comm)

class Net:
    """Directed network of commands learned from example traces.

    Every trace passed to :meth:`insert` becomes a route from ``self.start``
    (named ``_start``) to ``self.end`` (named ``_end``).  Runs of commands
    that already exist in the network are reused instead of duplicated.
    """

    def __init__(self):
        self.start = Command("_start")
        self.end = Command("_end")
        # Cursor node; reset to ``self.start`` by ``all_nodes``.
        self.activ = self.start
        # Names of every command seen during training.
        self.known_commands = []

    def count_state(self, state):
        """Append *state* and every node reachable from it to ``self.nodes``.

        Depth-first, pre-order.  A successor is skipped when an equal node
        (same ``idn``) is already in ``self.nodes``.  ``self.nodes`` must
        already exist (see :meth:`all_nodes`).  Recursion depth grows with
        the length of the longest chain.
        """
        self.nodes.append(state)
        for s in state.succs:
            if s not in self.nodes:
                self.count_state(s)

    def all_nodes(self):
        """Return all nodes reachable from ``self.start``, start included.

        Also resets ``self.activ`` to ``self.start`` and stores the list in
        ``self.nodes``.
        """
        self.nodes = []
        self.activ = self.start
        self.count_state(self.start)
        return self.nodes

    def who_points_at(self, node):
        """Return every reachable node that lists *node* among its successors."""
        return [n for n in self.all_nodes() if node in n.succs]

    def maxidn(self):
        """Return the largest ``idn`` of a reachable node, or 0 if none is larger."""
        return max([0] + [n.idn for n in self.all_nodes()])

    def dynamic(self, L1, L2):
        """Find the longest run of consecutive nodes shared by *L1* and *L2*.

        *L1* is a sequence of nodes and *L2* a list of node sequences.  Nodes
        are compared by ``name``.  This is the longest common *contiguous*
        run (substring), not a subsequence.  The result is a slice of *L1*.
        Among runs of equal length, the last one found wins.  Runs of
        length one are not useful and yield ``[]``.
        """
        maxi = 0
        ret = []
        for seq in L2:
            # Rolling DP row: prev[j] is the length of the common run that
            # ends at the previous element of L1 and at seq[j].
            prev = [0] * len(seq)
            for i, a in enumerate(L1):
                cur = [0] * len(seq)
                for j, b in enumerate(seq):
                    if a.name == b.name:
                        cur[j] = prev[j - 1] + 1 if j > 0 else 1
                        if cur[j] >= maxi:
                            maxi = cur[j]
                            ret = L1[i - maxi + 1:i + 1]
                prev = cur
        # Sequence of length one is not needed
        if len(ret) == 1:
            ret = []
        return ret

    def path_to_nodes(self, path):
        """Turn *path* (successor indexes from start) into the visited nodes.

        The returned list starts with ``self.start`` and has one more
        element than *path*.
        """
        activ = self.start
        ret = []
        for p in path:
            ret.append(activ)
            activ = activ.succs[p]
        ret.append(activ)
        return ret

    def get_node(self, path):
        """Return the node reached by following *path* from ``self.start``."""
        activ = self.start
        for p in path:
            activ = activ.succs[p]
        return activ

    @staticmethod
    def _matches_at(seq, start, common):
        """Return True if ``seq[start:]`` begins with the names of *common*."""
        if start + len(common) > len(seq):
            return False
        return all(seq[start + k].name == c.name for k, c in enumerate(common))

    def find_in_path(self, path, common):
        """Locate the first run in *path* whose names equal those of *common*.

        Returns the inclusive ``(start, end)`` indexes of that run, or
        ``(None, None)`` when *common* is empty or does not occur.
        """
        if not common:
            return (None, None)
        for i in range(len(path) - len(common) + 1):
            if self._matches_at(path, i, common):
                return (i, i + len(common) - 1)
        return (None, None)

    def remove_common(self, commands, common, path):
        """Merge the run *common* (taken from network *path*) into each segment.

        *commands* is a list of segments, each a list of new (not yet merged)
        nodes.  In every segment, the first run whose names equal *common* is
        replaced by the matching network nodes found along *path*:

        - cycle flags are copied onto the network nodes;
        - edges into and out of the run are redirected to the network nodes;
        - the run is cut out, splitting the segment in two.

        Returns the resulting list of non-empty segments.
        """
        n = len(common)
        if n == 0:
            # Nothing to merge.
            return [c for c in commands if c]
        ret = []
        for c in commands:
            c1 = []
            c2 = []
            # Bound the scan so a partial match at the end of a segment cannot
            # index past it.
            for i in range(len(c) - n + 1):
                if not self._matches_at(c, i, common):
                    continue
                last = i + n - 1  # index in c of the last common node
                s, e = self.find_in_path(self.path_to_nodes(path), common)
                # If common nodes can be cycled, mark them in the network
                for k in range(n):
                    if c[i + k].cycle:
                        self.get_node(path[:s + k]).cycle = True
                # End of the common nodes in the network points at the rest
                # of the segment
                if last + 1 < len(c):
                    self.get_node(path[:e]).add_succ(c[last + 1])
                # If the sequence has already been split and the common run
                # ends the segment, the network's copy must point where the
                # segment's last node pointed
                if c[-1].succs and last + 1 == len(c):
                    self.get_node(path[:e]).add_succ(c[-1].succs[0])
                # End of the preceding non-common part points at the start
                # of the common run in the network
                if i > 0:
                    c[i - 1].succs[0] = self.get_node(path[:s])
                # Every network node that pointed at the start of the run now
                # points at the network's copy of it
                for pred in self.who_points_at(c[i]):
                    for k in range(len(pred.succs)):
                        if pred.succs[k] == c[i]:
                            pred.succs[k] = self.get_node(path[:s])

                c[i:last + 1] = []
                c1 = c[:i]
                c2 = c[i:]
                break

            if c1 or c2:
                ret.extend(part for part in (c1, c2) if part)
            elif c:
                # No match in this segment.  A segment emptied by the removal
                # is dropped so callers can index segment[0] / segment[-1].
                ret.append(c)

        return ret

    def generate_paths(self, node, ppath):
        """Return every path from *node* to a node without successors.

        Each path is a list of successor indexes.  *ppath* holds the nodes
        already on the current path.  When *node* is already in it (a cycle),
        that branch contributes no path.
        """
        if not node.succs:
            return [[]]
        # If there is a cycle in the FSM, stop generating
        if node in ppath:
            return []
        visited = list(ppath) + [node]
        return [[i] + tail
                for i, s in enumerate(node.succs)
                for tail in self.generate_paths(s, visited)]

    def str_to_nodes(self, commands):
        """Convert a list of command names into a chain of new Command nodes.

        Consecutive repetitions of a name collapse into one node with
        ``cycle`` set.  Each node's single successor is the next node in the
        returned list.  Raises IndexError if *commands* is empty.
        """
        prev = Command(commands[0])
        ret = [prev]
        for cmd in commands[1:]:
            if cmd != prev.name:
                c = Command(cmd)
                prev.add_succ(c)
                prev = c
                ret.append(c)
            else:
                prev.cycle = True
        return ret

    def insert(self, commands):
        """Add one trace (a sequence of command names) to the network.

        The first trace becomes a simple chain ``_start -> ... -> _end``.
        A later trace is wrapped in ``_start``/``_end`` markers.  It is then
        repeatedly compared with every network path, and the longest shared
        run (length >= 2) is merged into existing nodes via
        :meth:`remove_common`.  When no such run is left, the leftover
        segments are hooked onto the network's own start and end nodes.
        """
        cmds = list(commands)
        self.known_commands.extend(cmds)
        self.known_commands = list(set(self.known_commands))
        if not self.start.succs:
            # Network is empty
            path = self.str_to_nodes(cmds)
            self.start.add_succ(path[0])
            path[-1].add_succ(self.end)
        else:
            cmds.insert(0, "_start")
            cmds.append("_end")
            cmds = [self.str_to_nodes(cmds)]
            paths = self.generate_paths(self.start, [])
            # Search for the longest common run, then again in the new
            # subsequences
            while True:
                found = False
                common = []
                lpath = []
                for p in paths:
                    tmp = self.dynamic(self.path_to_nodes(p), cmds)
                    if len(tmp) > len(common):
                        common = tmp
                        lpath = p
                        found = True
                if not found:
                    # The trace's own _start/_end nodes may survive; merge
                    # them into the network's start and end nodes
                    if cmds and cmds[0]:
                        if cmds[0][0].name == "_start":
                            self.start.add_succ(cmds[0][0].succs[0])
                        if cmds[-1][-1].name == "_end" and len(cmds[-1]) == 1:
                            pred = self.who_points_at(cmds[-1][-1])[0]
                            pred.add_succ(self.end)
                            pred.succs.remove(cmds[-1][-1])
                        if cmds[-1][-1].name == "_end" and len(cmds[-1]) > 1:
                            cmds[-1][-2].succs[0] = self.end
                    break
                cmds = self.remove_common(cmds, common, lpath)

    def remove_pair_from_list(self, pair, lst):
        """Return a copy of *lst* without the first pair equal to *pair* by idn.

        If no such pair exists, an unchanged copy of *lst* is returned.
        """
        for i, p in enumerate(lst):
            if (p[0].idn == pair[0].idn) and (p[1].idn == pair[1].idn):
                return lst[:i] + lst[(i + 1):]
        return list(lst)

    def gen_pairs(self):
        """Return all edges (node, successor) whose name pair occurs on more
        than one distinct edge.

        Edges are grouped by ``(node.name, succ.name)`` in order of first
        appearance.  Within a group, edges are de-duplicated by
        ``(node.idn, succ.idn)``, keeping first appearance.  Groups holding a
        single distinct edge are omitted.
        """
        # (pred name, succ name) -> {(pred idn, succ idn): (pred, succ)}
        groups = {}
        for node in self.all_nodes():
            for succ in node.succs:
                group = groups.setdefault((node.name, succ.name), {})
                group.setdefault((node.idn, succ.idn), (node, succ))
        ret = []
        for group in groups.values():
            if len(group) > 1:
                ret.extend(group.values())
        return ret

    def merge_same(self):
        """Merge nodes that still form duplicate (name, name) edges.

        It may happen that some equal subsequences were not merged by
        :meth:`insert`.  Each round takes the first duplicated edge ``p1``
        from :meth:`gen_pairs`.  For every other edge ``p2`` with the same
        names, it redirects the predecessors of ``p2``'s nodes onto ``p1``'s
        nodes and appends ``p2``'s successors.  Rounds repeat until
        :meth:`gen_pairs` is empty.
        """
        # NOTE(review): termination relies on every round reducing the set
        # of duplicate pairs; not proven here.
        pairs = self.gen_pairs()
        while pairs:
            p1 = pairs[0]
            for p2 in pairs[1:]:
                if p1[0].name == p2[0].name and p1[1].name == p2[1].name:
                    if p1[0] != p2[0]:
                        for pred in self.who_points_at(p2[0]):
                            pred.succs.append(p1[0])
                            pred.succs.remove(p2[0])
                        p1[0].succs += p2[0].succs
                    if p1[1] != p2[1]:
                        for pred in self.who_points_at(p2[1]):
                            pred.succs.append(p1[1])
                            pred.succs.remove(p2[1])
                        p1[1].succs += p2[1].succs
            pairs = self.gen_pairs()

    def draw(self, args):
        """Render the network to ``<globs.fdir>/output/<args.trid>.method2.pdf``.

        Nodes with ``cycle`` set get a self-loop.  Error nodes get a red
        label containing their message, and edges into them are coloured red.
        """
        drawing = AGraph(directed=True)
        drawing.node_attr["shape"] = "none"
        drawing.graph_attr["rankdir"] = "LR"

        def add_labelled(n):
            # Add (or update) node n with its label.  The text of an
            # HTML-like label is escaped, because a '<', '>' or '&' in the
            # name or the user-supplied errmsg would make the label invalid.
            drawing.add_node(n.idn, label=n.name)
            if n.errnode:
                errlbl = ("<<FONT COLOR=\"RED\">"
                          + html.escape(n.name, quote=False)
                          + "<BR/>" + html.escape(n.errmsg, quote=False)
                          + "</FONT>>")
                drawing.get_node(n.idn).attr["label"] = errlbl

        for node in self.all_nodes():
            if node.cycle:
                add_labelled(node)
                drawing.add_edge(node.idn, node.idn)
            for succ in node.succs:
                add_labelled(node)
                add_labelled(succ)
                drawing.add_edge(node.idn, succ.idn)
                if succ.errnode:
                    drawing.get_edge(node.idn, succ.idn).attr["color"] = "red"

        drawing.layout("dot")
        drawing.draw(globs.fdir + "/output/" + args.trid + ".method2.pdf")

def train_correct(commands, args):
    """Train a network on correct traces, render it and store it.

    Every trace in *commands* (a list of command names) is inserted into a
    fresh :class:`Net`, and leftover duplicate edges are merged.  The network
    is drawn, then pickled to ``<globs.fdir>/output/<args.trid>.method2``.
    Returns the trained network.
    """
    model = Net()
    for trace in commands:
        model.insert(trace)
    model.merge_same()
    model.draw(args)

    model_path = globs.fdir + "/output/" + args.trid + ".method2"
    with open(model_path, 'wb') as handle:
        pickle.dump(model, handle)
    return model

def train_incorrect(commands, args):
    """Record one incorrect trace in the stored network.

    The network is loaded from ``<globs.fdir>/output/<args.trid>.method2`` and
    walked with the trace.  At the first command with no matching successor,
    that command is attached as an error node carrying ``args.errmsg``.  If
    the walk instead lands on an existing error node, ``args.errmsg`` is
    appended to its message.  The network is then redrawn and saved back.
    """
    model_path = globs.fdir + "/output/" + args.trid + ".method2"
    with open(model_path, 'rb') as handle:
        net = pickle.load(handle)
    # Restore class variable: continue numbering after the loaded nodes
    Command.idn_pool = net.maxidn() + 1
    trace = net.str_to_nodes(commands)

    current = net.start
    for cmd in trace:
        # A repeated command only matches a successor that is repeatable too
        nxt = next(
            (s for s in current.succs
             if cmd.name == s.name and (not cmd.cycle or s.cycle)),
            None,
        )
        if nxt is None:
            if cmd.name not in net.known_commands:
                net.known_commands.append(cmd.name)
            cmd.succs = []
            cmd.errnode = True
            cmd.errmsg = args.errmsg
            current.add_succ(cmd)
            break
        current = nxt
        if current.errnode:
            current.errmsg += " OR " + args.errmsg
            break

    net.draw(args)
    with open(model_path, "wb") as handle:
        pickle.dump(net, handle)
   
def check(commands, args):
    """Validate a trace against the network stored for ``args.trid``.

    The pickled network is loaded from ``<globs.fdir>/output/<args.trid>.method2``.
    The trace (a list of command names) is then walked from the start node.
    After an unknown command, the walk tries to resynchronise with the model
    by using the next command in the trace.

    Returns:
        True when the trace is accepted.

    Raises:
        RuntimeError: if a command matches an error node recorded by
            ``train_incorrect``, or if a known command appears where the
            network does not allow it and the previous command was not
            unknown.
    """
    # NOTE(review): pickle.load can execute arbitrary code; this assumes the
    # model file is produced locally by train_correct/train_incorrect.
    with open(globs.fdir + "/output/" + args.trid + ".method2", 'rb') as fp:
        net = pickle.load(fp)
    commands = net.str_to_nodes(commands)
    activ = net.start

    last_unknown = False  # True if the previous command was unknown
    for i, c in enumerate(commands):
        found = False
        for s in activ.succs:
            if c.name == s.name:
                if s.errnode:
                    raise RuntimeError("Method 2: Error found - " + s.errmsg)
                # If a command in the trace was repeated, the model has to
                # allow the repetition
                if not (c.cycle and not s.cycle):
                    last_unknown = False
                    activ = s
                    found = True
                    break
        if found:
            continue

        if c.name in net.known_commands and not last_unknown:
            raise RuntimeError("Method 2: Unknown command")

        if (i + 1) == len(commands):
            return True

        # Decide whether the unknown command replaced one of activ's
        # successors.  If the next command is itself a direct successor of
        # activ, the unknown command was merely inserted and activ stays.
        # Otherwise skip to the first successor s1 that leads to the next
        # command.  The test is made per candidate, so the outcome does not
        # depend on the order of successors.
        next_name = commands[i + 1].name
        if all(s.name != next_name for s in activ.succs):
            for s1 in activ.succs:
                if any(s2.name == next_name for s2 in s1.succs):
                    activ = s1
                    break

        last_unknown = True
    return True
