import pickle
from transitions.extensions import GraphMachine as Machine
from transitions import core
from pygraphviz import *

import src.globs as globs

# Status constants. NOTE(review): presumably success / failure return codes;
# neither is referenced in the visible part of this file -- confirm usage
# elsewhere before relying on or removing them.
OK = 0
FAIL = -1

class ProtocolFSM(Machine):
    """Finite state machine over protocol commands with error annotations.

    A transition is identified by a ``(source, dest)`` tuple. Error messages
    attached to transitions are stored in ``error_strings`` as
    ``(transition, message)`` pairs; ``known_commands`` holds the command
    names the machine has been trained on.
    """

    def __init__(self, states, trans):
        """Create the machine, starting in the ``"_start"`` state.

        Args:
            states: List of state names handed to the base machine.
            trans: List of transition definitions handed to the base machine.
        """
        Machine.__init__(
            self,
            states=states,
            transitions=trans,
            initial="_start",
            auto_transitions=False)

        # Copy so that later appends to known_commands never mutate the
        # list object owned by the caller.
        self.known_commands = list(states)
        self.error_strings = []

    def check_error(self, transition):
        """Return the error message recorded for *transition*.

        Args:
            transition: ``(source, dest)`` tuple.

        Returns:
            The first message stored for the transition, or the integer ``0``
            when none is stored (callers compare against ``0``).
        """
        for t, msg in self.error_strings:
            if t == transition:
                return msg
        return 0

    def add_error(self, transition, errmsg):
        """Record *errmsg* as the error message of *transition*."""
        self.error_strings.append((transition, errmsg))

    def modify_error(self, transition, errmsg):
        """Replace every message stored for *transition* with *errmsg*."""
        # Filter into a new list instead of calling remove() while iterating:
        # removing during iteration skips the following element, so a second
        # entry for the same transition could survive and shadow the new
        # message in check_error(). Slice assignment keeps the list object.
        self.error_strings[:] = [
            (t, msg) for t, msg in self.error_strings if t != transition
        ]
        self.add_error(transition, errmsg)

    def draw(self, args):
        """Render the machine to a PDF using the "dot" layout.

        The file is written to
        ``globs.fdir + "/output/" + args.trid + ".method1.pdf"``. Transitions
        that carry an error message get a red edge, and their target node is
        filled red and labelled with the message.

        Args:
            args: Object providing a ``trid`` attribute used in the file name.
        """
        drawing = AGraph(directed = True)
        drawing.node_attr["style"] = "filled"
        drawing.node_attr["shape"] = "circle"

        for state in self.states:
            # Trigger names double as successor node names here; this file's
            # create_transitions() names every trigger after its destination.
            for succ in self.get_triggers(state):
                drawing.add_edge(state, succ)
                msg = self.check_error((state, succ))
                if msg == 0:
                    continue
                node = drawing.get_node(succ)
                node.attr["fillcolor"] = "red"
                # NOTE(review): succ and msg go into an HTML-like label
                # unescaped; a message containing '<', '>' or '&' would
                # presumably break the label -- confirm and escape if needed.
                node.attr["label"] = "<" + succ + "<BR/>" + msg + ">"
                edge = drawing.get_edge(state, succ)
                edge.attr["color"] = "red"

        drawing.layout("dot")
        drawing.draw(globs.fdir + "/output/" + args.trid + ".method1.pdf")

def create_transitions(commands):
    """Build the unique transitions implied by a set of command traces.

    Every trace starts at the ``'_start'`` state; each command ``c`` that
    follows state ``p`` yields the transition
    ``{'trigger': c, 'source': p, 'dest': c}``.

    Args:
        commands: Iterable of traces, each an iterable of hashable command
            names.

    Returns:
        List of transition dicts without duplicates, in first-seen order.
    """
    trans = []
    seen = set()
    for pcap in commands:
        prev_c = '_start'
        for c in pcap:
            # Since dest always equals trigger, (source, command) identifies
            # a transition. Tracking seen pairs keeps first-seen order, which
            # a round-trip through set() would scramble between runs.
            edge = (prev_c, c)
            if edge not in seen:
                seen.add(edge)
                trans.append({'trigger': c, 'source': prev_c, 'dest': c})
            prev_c = c
    return trans

def train_correct(commands, args):
    """Train a new FSM from correct command traces, draw it and save it.

    Args:
        commands: Iterable of traces, each an iterable of command names.
        args: Object providing ``trid``, used to name the output files.

    Returns:
        The newly built ProtocolFSM. It is also pickled to
        ``globs.fdir + "/output/" + args.trid + ".method1"``.
    """
    # Collect the commands of all traces, dropping duplicates but keeping
    # first-seen order so the state list is the same on every run.
    states = []
    seen = set()
    for pcap in commands:
        for c in pcap:
            if c not in seen:
                seen.add(c)
                states.append(c)
    # The initial state must be present; avoid listing it twice if a trace
    # happens to contain a command with the same name.
    if "_start" not in seen:
        states.append("_start")

    trans = create_transitions(commands)
    fsm = ProtocolFSM(states, trans)
    fsm.draw(args)
    with open(globs.fdir + "/output/" + args.trid + ".method1", 'wb') as fp:
        pickle.dump(fsm, fp)
    return fsm

def train_incorrect(commands, args):
    """Extend the stored FSM with an erroneous command trace.

    The trace is replayed from the machine's current state. Commands that are
    not valid triggers in the current state get a new transition; the first
    command that needs a brand-new state is marked with ``args.errmsg`` and
    ends the replay. If the replay reaches a transition that already carries
    an error message, ``args.errmsg`` is joined to it with ``" OR "`` and the
    replay ends. The machine is then reset to ``"_start"``, redrawn and
    pickled back to the same file.

    Args:
        commands: Sequence of command names forming one trace.
        args: Object providing ``trid`` (model file name) and ``errmsg``.
    """
    path = globs.fdir + "/output/" + args.trid + ".method1"
    # NOTE: pickle.load can execute arbitrary code; only load model files
    # that this tool wrote itself.
    with open(path, 'rb') as fp:
        fsm = pickle.load(fp)

    for c in commands:
        # fsm.state only changes in fsm.trigger() at the end of an iteration,
        # so the (source, dest) pair can be computed once up front.
        edge = (fsm.state, c)
        if c not in fsm.get_triggers(fsm.state):
            # Only membership of known_commands is ever tested, so avoid
            # piling up duplicate entries in the pickled model.
            if c not in fsm.known_commands:
                fsm.known_commands.append(c)
            is_new_state = c not in fsm.states
            if is_new_state:
                fsm.add_states(c)
                fsm.add_error(edge, args.errmsg)
            fsm.add_transition(c, source = fsm.state, dest = c)
            if is_new_state:
                # The error state has just been created; stop replaying.
                break
        msg = fsm.check_error(edge)
        #Same error node already exists in FSM
        if msg != 0:
            # Do not repeat a message that is already one of the recorded
            # alternatives (retraining the same trace gave "E OR E").
            if args.errmsg not in msg.split(" OR "):
                fsm.modify_error(edge, msg + " OR " + args.errmsg)
            break

        fsm.trigger(c)
    # Reset so the persisted machine always starts from the initial state.
    fsm.state = "_start"

    fsm.draw(args)
    with open(path, 'wb') as fp:
        pickle.dump(fsm, fp)

def check(commands, args):
    """Validate a command trace against the stored FSM.

    Args:
        commands: Sequence of command names forming one trace.
        args: Object providing ``trid``, which names the pickled model file.

    Returns:
        True when the whole trace is processed without hitting an error.

    Raises:
        RuntimeError: If the trace takes a transition that carries an error
            message, or if a known command appears where it is not a valid
            trigger and the previous command was not unknown.
    """
    # NOTE: pickle.load can execute arbitrary code; only load model files
    # that this tool wrote itself.
    with open(globs.fdir + "/output/" + args.trid + ".method1", 'rb') as fp:
        fsm = pickle.load(fp)

    last_unknown = False  #Indicator if last command was unknown
    for i, c in enumerate(commands):
        if c in fsm.get_triggers(fsm.state):
            last_unknown = False
            msg = fsm.check_error((fsm.state, c))
            if msg != 0:
                raise RuntimeError("Method 1: Error found - " + msg)
            fsm.trigger(c)
            continue

        #Known command in wrong place --> error
        # NOTE(review): the message says "Unknown command" although this
        # branch fires for a *known* command; text kept unchanged in case
        # callers match on it.
        if c in fsm.known_commands and not last_unknown:
            raise RuntimeError("Method 1: Unknown command")

        #An unknown command at the very end of the trace is accepted
        if (i + 1) == len(commands):
            break

        # Try to skip over the unknown command: advance to the first
        # successor from which the next command is a valid trigger. If no
        # successor qualifies, the machine stays where it is.
        nc = commands[i + 1]
        for ns in fsm.get_triggers(fsm.state):
            if nc in fsm.get_triggers(ns):
                fsm.trigger(ns)
                break
        last_unknown = True
    return True
