from threading import Lock
import inspect
import csv

from pyretic.modules.hub import *
from pyretic.modules.mac_learner import *
from pyretic.lib.corelib import *
from pyretic.lib.std import *
from pyretic.lib.query import *
from pyretic.examples.event_listener import *
from pyretic.examples.routing import *
from pyretic.examples.statistics import *
from pyretic.examples.firewall_resource import *
from pyretic.examples.firewall_groups import *
from pyretic.examples.shared import *
from pyretic.core.runtime import virtual_field

usersFile = "pyretic/examples/configs/users.csv"
staticUsersFile = "pyretic/examples/configs/static-users.csv"
usernames = {}

class identityManagementArp(DynamicPolicy):
    def __init__(self, identityPolicy):
        self.debug = 0
        self.identityPolicy = identityPolicy
        self.network = None
        self.topology = None
        self.lock = Lock()
        self.macList = {}
        self.ipList = {}
        self.locationList = {}
        self.portsList2 = []
        self.backet = packets()
        self.backet.register_callback(self.handle_arp)
        self.socketBuffer = []
        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            self.socket.connect("/tmp/iricol")
        except socket.error:
            print "WARNING: identityManagementArp > __init__ > connect(): nie je mozne sa pripojit na socket"
        self.staticEntries = () #(("aa:aa:aa:aa:aa:a1", "4.4.4.1", "aa:aa:aa:aa:aa:aa-1"))
        self.static_entries()
        super(identityManagementArp,self).__init__(true)
        self.update_policy()
    
    def static_entries(self):
        for pair in self.staticEntries:
            self.send_iri("BEGIN", pair[0], pair[1], pair[2])
    
    def time_now(self):
        unixTime = str(time.time())
        parts = unixTime.split('.')
        if len(parts[1]) == 1:
            parts[1] = parts[1] + "0"
        return str(int(time.mktime(time.localtime()))) + "." + parts[1]
    
    def clear_iri_message_buffer(self):
        if len(self.socketBuffer) > 0:
            self.socket.connect("/tmp/iricol")
            
            while len(self.socketBuffer) > 0:
                backup = self.socketBuffer[0]
                self.socket.sendall(backup)
                with self.lock:
                    del self.socketBuffer[0]
            
            if self.debug:
                print "identityManagementArp > send_iri > sendall(): buffer je vyprazdneny"
    
    def send_iri(self, type, mac, ip, location):
        description = "error"
        if type == "BEGIN":
            description = "New IP address has been detected"
        elif type == "CONTINUE":
            description = "User confirms, that he is still using IP address"
        elif type == "END":
            description = "IP address is no longer connected"
        
        message = "('ARP', "+str(self.time_now())+", '"+type+"', '"+description+"', [('MAC', '"+str(mac)+"'), ('IPv4', '"+str(ip)+"'), ('SDN location', '"+location+"')])\n"
        
        if type != "CONTINUE":
            print message[:-1]
        
        try:
            self.clear_iri_message_buffer()
            self.socket.sendall(message)
        except socket.error:
            print "WARNING: identityManagementArp > send_iri > sendall(): nie je mozne odoslat spravu, uklada sa do bufferu -", self.socketBuffer
            with self.lock:
                self.socketBuffer.append(message)
    
    def set_network(self, network):
        with self.lock:
            if self.topology and (self.topology == network.topology):
                pass
            else:
                self.topology = network.topology
                self.network = network
                for event in self.topology.lastEvents:
                    if event["type"] == "port_join":
                        try:
                            if self.topology.node[event["switch"]]["ports"][event["port"]].linked_to == None: # ignorujem porty zapojene do switchu
                                self.portsList2.append({'switch': event["switch"], 'port': event["port"], 'mac': None});
                        except KeyError:
                            print "WARNING: identityManagementArp > set_network > port_join > KeyError:", event["switch"], event["port"], self.topology.node
                            
                    elif event["type"] == "port_part":
                        for port in self.portsList2:
                            if port["switch"] == event["switch"] and port["port"] == event["port"]:
                                self.portsList2.remove(port)
                                if port["mac"] != None:
                                    self.delete_by_mac(port["mac"])
                                
                    elif event["type"] == "link_add":
                        for port in self.portsList2:
                            if port["switch"] == event["switch1"] and port["port"] == event["port1"]:
                                self.portsList2.remove(port)
                                if port["mac"] != None:
                                    self.delete_by_mac(port["mac"])
                                
                            elif port["switch"] == event["switch2"] and port["port"] == event["port2"]:
                                self.portsList2.remove(port)
                                if port["mac"] != None:
                                    self.delete_by_mac(port["mac"])
                    
                    elif event["type"] == "switch_part":
                        for port in self.portsList2:
                            if port["switch"] == event["switch"]:
                                self.portsList2.remove(port)
                                if port["mac"] != None:
                                    self.delete_by_mac(port["mac"])
                        
                    
                self.update_policy()
    
    def delete(self, mac, ip, location):
        try:
            del self.macList[mac]
            del self.locationList[mac]
            del self.ipList[ip]
            
            self.send_iri("END", mac, ip, location)
        except KeyError:
            print "WARNING: identityManagementArp > delete > KeyError:", mac, ip, self.macList, self.locationList, self.ipList
    
    def delete_by_mac(self, mac):
        try:
            ip = self.macList[mac]
            location = self.locationList[mac]
            self.delete(mac, ip, location)
        except KeyError:
            print "WARNING: identityManagementArp > delete_by_mac > KeyError:", mac, self.macList, self.locationList
    
    def delete_by_ip(self, ip):
        try:
            mac = self.ipList[ip]
            location = self.locationList[mac]
            self.delete(mac, ip, location)
        except KeyError:
            print "WARNING: identityManagementArp > delete_by_ip > KeyError:", ip, self.ipList, self.locationList
    
    def create_location(self, switch, port):
        dpid = hex(switch)[2:].zfill(12)
        dpid = ":".join([dpid[i:i+2] for i in range(0, len(dpid), 2)])
        return dpid + "-" + str(port)
    
    def decode_location(self, location):
        switch, outport = location.split('-', 1)
        switch = int(switch.translate(None, ':'), 16)
        outport = int(outport)
        return switch, outport
    
    def add_mac_ip(self, mac, ip, switch, port):
        location = self.create_location(switch, port)
        
        self.macList[mac] = ip
        self.locationList[mac] = location
        self.ipList[ip] = mac
        self.send_iri("BEGIN", mac, ip, location)
        
        for portIndex in range(0, len(portsList)):
            if portsList[portIndex]["switch"] == switch and portsList[portIndex]["port"] == port:
                found = False
                for deviceIndex in range(0, len(portsList[portIndex]["devices"])):
                    if portsList[portIndex]["devices"][deviceIndex]["mac"] == mac:
                        found = True
                        if portsList[portIndex]["devices"][deviceIndex]["ip"] != ip:
                            portsList[portIndex]["devices"][deviceIndex]["ip"] = ip
                            self.update_policy()
                if found == False:
                    portsList[portIndex]["devices"].append({'mac': mac, 'ip': ip, 'username': 'default', 'group': 'default'})
                    self.update_policy()
                return
    
    def save_mac_ip_from_arp_packet(self, pkt):
        if pkt['srcmac'] not in self.macList:
            if pkt['srcip'] not in self.ipList:
                self.add_mac_ip(pkt['srcmac'], pkt['srcip'], pkt['switch'], pkt['inport'])
            else:
                self.delete_by_ip(pkt['srcip'])
                self.add_mac_ip(pkt['srcmac'], pkt['srcip'], pkt['switch'], pkt['inport'])
        else:
            if pkt['srcip'] not in self.ipList:
                self.delete_by_mac(pkt['srcmac'])
                self.add_mac_ip(pkt['srcmac'], pkt['srcip'], pkt['switch'], pkt['inport'])
            else:
                if self.macList[pkt['srcmac']] == pkt['srcip']:
                    self.send_iri("CONTINUE", pkt['srcmac'], pkt['srcip'], self.locationList[pkt['srcmac']])
                else:
                    self.delete_by_ip(pkt['srcip'])
                    self.delete_by_mac(pkt['srcmac'])
                    self.add_mac_ip(pkt['srcmac'], pkt['srcip'], pkt['switch'], pkt['inport'])
    
    def handle_arp(self, pkt):
        self.save_mac_ip_from_arp_packet(pkt)
        
        if pkt['protocol'] == 1: # ARP request
            if self.debug:
                print "ARP REQUEST from", self.create_location(pkt['switch'], pkt['inport']) ,"about", pkt['dstip']
            # forward request out of all egress ports
            for loc in self.network.topology.egress_locations() - {Location(pkt['switch'], pkt['inport'])}:
                switch  = loc.switch
                outport = loc.port_no
                pkt = pkt.modify(switch=switch)
                pkt = pkt.modify(outport=outport)
                try:
                    self.network.inject_packet(pkt)
                except Exception:
                    print "WARNING: identityManagementArp > handle_arp > ARP request > inject_packet:", pkt, self.network
        
        elif pkt['protocol'] == 2: # ARP response
            if self.debug:
                print "ARP RESPONSE from", self.create_location(pkt['switch'], pkt['inport']), "about", pkt['srcip']
            
            # forward request out of port with destination MAC address
            if pkt['dstmac'] in self.locationList:
                location = self.locationList[pkt['dstmac']]
                switch, outport = self.decode_location(location)
                pkt = pkt.modify(switch=switch)
                pkt = pkt.modify(outport=outport)
                try:
                    self.network.inject_packet(pkt)
                except Exception:
                    print "WARNING: identityManagementArp > handle_arp > ARP response > inject_packet:", pkt, self.network
            
    
    def update_policy(self):
        if len(self.portsList2) == 0:
            self.policy = drop
        else:
            self.policy = match(ethtype=ARP_TYPE, switch=self.portsList2[0]["switch"], inport=self.portsList2[0]["port"]) >> identity
            for i in range(1, len(self.portsList2)):
                self.policy = self.policy + match(ethtype=ARP_TYPE, switch=self.portsList2[i]["switch"], inport=self.portsList2[i]["port"]) >> identity
        
        self.policy = self.policy >> self.backet
        if self.debug:
            print "identityManagementArp POLICY", self.policy
        self.identityPolicy.update_policy()

    
class identityManagementMark(DynamicPolicy):
    def set_network(self, network):
        global portsList
        with self.lock:
            if self.topology and (self.topology == network.topology):
                pass
            else:
                self.topology = network.topology
                for event in self.topology.lastEvents:
                    if event["type"] == "port_join":
                        if self.topology.node[event["switch"]]["ports"][event["port"]].linked_to == None: # ignorujem porty zapojene do switchu
                            found = False
                            for port in portsList:
                                if port["switch"] == event["switch"] and port["port"] == event["port"]:
                                    found = True
                                    break
                            if found == False:
                                portsList.append({'switch': int(event["switch"]), 'port': int(event["port"]), 'devices': []});
                    elif event["type"] == "port_part":
                        for port in portsList:
                            if port["switch"] == event["switch"] and port["port"] == event["port"]:
                                portsList.remove(port)
                    elif event["type"] == "link_add":
                        for port in portsList:
                            if port["switch"] == event["switch1"] and port["port"] == event["port1"]:
                                portsList.remove(port)
                            elif port["switch"] == event["switch2"] and port["port"] == event["port2"]:
                                portsList.remove(port)
                self.update_policy()
    
    def event_handler(self, event):
        global portsList
        username = event["username"]
        mac = EthAddr(event['mac']);
        ip = IPAddr(event['ip']);
        
        switch, port = event["location"].split('-', 1)
        switch = int(switch.translate(None, ':'), 16)
        port = int(port)
        
        updatePolicy = False
        
        if event["type"] == "BEGIN":
            group = usernames[username] if username in usernames else "default"
            portFound = False
            for portIndex in range(0, len(portsList)):
                if portsList[portIndex]["switch"] == switch and portsList[portIndex]["port"] == port:
                    portFound = True
                    deviceFound = False
                    for deviceIndex in range(0, len(portsList[portIndex]["devices"])):
                        if portsList[portIndex]["devices"][deviceIndex]["mac"] == mac:
                            deviceFound = True
                            
                            if portsList[portIndex]["devices"][deviceIndex]["username"] != username or portsList[portIndex]["devices"][deviceIndex]["group"] != group:
                                updatePolicy = True
                                
                                portsList[portIndex]["devices"][deviceIndex]["username"] = username
                                portsList[portIndex]["devices"][deviceIndex]["group"] = group
                            
                            break
                    
                    if deviceFound == False:
                        portsList[portIndex]["devices"].append({'mac': mac, 'ip': ip, 'username': username, 'group': group})
                        updatePolicy = True
                    
                    break
                    
            if portFound == False:
                portsList.append({'switch': switch, 'port': port, 'devices': [{'mac': mac, 'ip': ip, 'username': username, 'group': group}]})
                updatePolicy = True
            
        elif event["type"] == "END":
            for portIndex in range(0, len(portsList)):
                if portsList[portIndex]["switch"] == switch and portsList[portIndex]["port"] == port:
                    for deviceIndex in range(0, len(portsList[portIndex]["devices"])):
                        if portsList[portIndex]["devices"][deviceIndex]["mac"] == mac:
                            portsList[portIndex]["devices"][deviceIndex]["username"] = 'default'
                            portsList[portIndex]["devices"][deviceIndex]["group"] = 'default'
                            updatePolicy = True
        
        if updatePolicy:
            self.update_policy()
    
    def __init__(self, handlers):
        import re
        self.debug = 0
        self.topology = None
        self.lock = Lock()
        self.handlers = handlers
        
        groups = ["default"]
        
        self.sims = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sims.connect((sims_ip, int(sims_port)))
        
        with open(usersFile, 'rb') as f:
            reader = csv.reader(f, delimiter=';')
            line = 0
            for row in reader:
                line = line + 1
                if line == 1:
                    continue
                usernames[row[0]] = row[1]
                if re.match(r'^(?:[0-9]{1,3}\.){3}[0-9]{1,3}$', str(row[0])) != None: # staticka IP adresa
                    print "monitoring identifier", "['" + str(row[0]) + "']"
                    self.sims.sendall(str(row[0]) + "\n")
                else:
                    for nidType in nidTypes:
                        print "monitoring identifier", "['" + nidType + ":" + str(row[0]) + "']"
                        self.sims.sendall(nidType + ":" + str(row[0]) + "\n")
                if row[1] not in groups:
                    groups.append(row[1])
        
        self.sims.sendall("load_all_data\n")
        
        virtual_field(name="group", values=groups, type="string")
        events = eventListener(self.event_handler, self.sims) # vytvorenie noveho vlakna pre prijem udalosti
        super(identityManagementMark,self).__init__(true)
        self.update_policy()
    
    def update_policy (self):
        self.policy = identity;
        for port in portsList:
            devicesAction = modify(group="default")
            for device in port["devices"]:
                devicesAction = if_(match(srcmac=device["mac"]), modify(group=device["group"]), devicesAction)
            self.policy = if_(match(switch=port["switch"], inport=port["port"]), devicesAction, self.policy)
            
        for handler in self.handlers:
            handler.update_policy()
        
        if self.debug:
            print "identitymanagement POLICY", self.policy

def identitymanagement(handlers):
    mark = identityManagementMark(handlers)
    arp = identityManagementArp(mark)
    return arp + mark

def main():
    return identitymanagement([])

