from threading import Lock
import inspect
import csv


from pyretic.examples.shared import *
from pyretic.lib.corelib import *
from pyretic.lib.std import *
from pyretic.lib.query import *
from pyretic.core.runtime import virtual_field

resourcesFile = "pyretic/examples/configs/resources.csv"
resourceRulesFile = "pyretic/examples/configs/routing.csv"
infiniteMetric = 100000000000

class routing(DynamicPolicy):
    def __init__(self):
        self.debug = 0
        self.network = None
        self.topology = None
        self.lock = Lock()
        self.rules = {}
        self.groups = []
        self.loadRules()
        self.newPolicy = drop
        self.dhcpMac = {}
        self.ipMapDb = {}
        self.dhcpIPList = dhcpServers
        self.backet = packets()
        self.backet.register_callback(self.handle_packet)
        super(routing,self).__init__(true)
        self.update_policy()

    def set_cost(self, switch1, switch2):
        for edgetype in self.topology[switch1][switch2]['type']:
            if edgetype[-2:] == "FD":
                part = edgetype.split('_')
                speed = part[1]
                mbps = int(speed[:-2])
                unit = speed[-2:]
                if unit == "GB":
                    mbps *= 1000
                cost = 1000000000 / mbps
                if cost < 1:
                    cost = 1
                self.topology[switch1][switch2]['weight'] = cost
                return
    
    def condition_eval(self, condition):
        import re
        
        condition = ' '.join(condition.split())
        if condition == "never":
            return False
        
        matchObj = re.search( r'([0-9]+-[0-9]+)', condition, re.M|re.I)
        while matchObj != None:
            link = matchObj.group()
            switch1 = int(link.split('-')[0])
            switch2 = int(link.split('-')[1])
            
            if (switch1, switch2) in self.topology.edges() or (switch2, switch1) in self.topology.edges():
                linkUp = False
            else:
                linkUp = True
            
            parts = condition.split(link, 1)
            condition = parts[0] + str(linkUp) + parts[1]
            
            matchObj = re.search( r'([0-9]+-[0-9]+)', condition, re.M|re.I)
        
        return eval(condition)
    
    def update_conditions(self):
        for link in self.rules:
            switch1 = int(link.split('-')[0])
            switch2 = int(link.split('-')[1])
            if switch1 in self.topology.nodes() and switch2 in self.topology.nodes():
                if switch2 in self.topology[switch1]:
                    self.set_cost(switch1, switch2)
                    self.set_cost(switch2, switch1)
                    for group in self.rules[link]:
                        if self.condition_eval(self.rules[link][group]) == False:
                            self.topology[switch1][switch2][group] = infiniteMetric
                        else:
                            self.topology[switch1][switch2][group] = self.topology[switch1][switch2]['weight']
    
    def set_network(self, network):
        with self.lock:
            if self.topology and (self.topology == network.topology):
                pass
            else:
                self.topology = network.topology
                self.network = network
                
                updateConditions = False
                for event in self.topology.lastEvents:
                    if event["type"] == "link_add":
                        updateConditions = True
                        self.set_cost(event["switch1"], event["switch2"])
                        self.set_cost(event["switch2"], event["switch1"])
                    elif event["type"] == "link_part":
                        updateConditions = True
                
                if updateConditions == True:
                    self.update_conditions()
                    self.update_policy()
    
    def loadRulesGroup(self, row, group):
        link = row[0]
        condition = row[2]
        
        if link not in self.rules:
            self.rules[link] = {}
        
        self.rules[link][group] = condition
    
    def loadRules(self):
        with open(resourceRulesFile, 'rb') as f:
            reader = csv.reader(f, delimiter=';')
            line = 0
            for row in reader:
                line = line + 1
                if line == 1:
                    continue
                
                for group in row[1].split(","):
                    if group not in self.groups:
                        self.groups.append(group)
                    self.loadRulesGroup(row, group)
    
    def createPolicyForGroup(self, policyForGroup):
        if self.topology != None:
            pass
        else:
            return
        
        for port in portsList:
            for switch in self.topology.nodes():
                try:
                    if policyForGroup == None:
                        length = nx.shortest_path_length(self.topology, switch, port["switch"], 'weight')
                    else:
                        length = nx.shortest_path_length(self.topology, switch, port["switch"], policyForGroup)
                except:
                    continue
                
                if length >= infiniteMetric:
                    for device in port["devices"]:
                        if policyForGroup == None:
                            self.newPolicy = if_(match(switch=switch,dstip=IPAddr(device["ip"]),ethtype=IP_TYPE), drop, self.newPolicy)
                        else:
                            self.newPolicy = if_(match(switch=switch,dstip=IPAddr(device["ip"]),ethtype=IP_TYPE,group=policyForGroup), drop, self.newPolicy)
                    continue
                
                if policyForGroup == None:
                    path = nx.shortest_path(self.topology, switch, port["switch"], 'weight')
                else:
                    path = nx.shortest_path(self.topology, switch, port["switch"], policyForGroup)
                
                # cielovy switch je totozny zo zdrojovym
                if len(path) == 1:
                    output = port["port"]
                else:
                    output = self.topology[switch][path[1]][switch]
                
                # vypocitana najkratsia cesta plati pre vsetky koncove stanice na cielovom switchi
                for device in port["devices"]:
                    if policyForGroup == None:
                        self.newPolicy = if_(match(switch=switch,dstip=IPAddr(device["ip"]),ethtype=IP_TYPE), fwd(output), self.newPolicy)
                    else:
                        self.newPolicy = if_(match(switch=switch,dstip=IPAddr(device["ip"]),ethtype=IP_TYPE,group=policyForGroup), fwd(output), self.newPolicy)
        
    
    def handle_arp(self, pkt):
        switch = pkt['switch']
        port = pkt['inport']
        ip = pkt['srcip']
        
        if switch not in self.ipMapDb:
                self.ipMapDb[switch] = {}
        
        newIp = True
        try:
            oldIp = self.ipMapDb[switch][port]
            if oldIp == ip:
                newIp = False
        except KeyError:
            pass
        
        self.ipMapDb[switch][port] = ip
        if newIp:
            self.update_policy()
    
    def dhcp_send_request(self, pkt, serverIP):
        for switch in self.ipMapDb:
            for port in self.ipMapDb[switch]:
                if self.ipMapDb[switch][port] == IPAddr(serverIP):
                    pkt = pkt.modify(switch=switch)
                    pkt = pkt.modify(outport=port)
                    try:
                        self.network.inject_packet(pkt)
                    except Exception:
                        print "WARNING: arp_routing > dhcp_send_request > inject_packet:", pkt, self.network
                    return
        
        for port in portsList:
            for device in port["devices"]:
                if device["ip"] == IPAddr(serverIP):
                    pkt = pkt.modify(switch=port["switch"])
                    pkt = pkt.modify(outport=port["port"])
                    try:
                        self.network.inject_packet(pkt)
                    except Exception:
                        print "WARNING: arp_routing > dhcp_send_request > inject_packet:", pkt, self.network
                    return
    
    def handle_dhcp(self, pkt):
        if pkt['srcport'] == 68: # pakety klient->server
            
            if Location(pkt['switch'], pkt['inport']) in self.topology.egress_locations():
                self.dhcpMac[pkt['srcmac']] = (pkt['switch'], pkt['inport'])
                
                if pkt['dstip'] == IPAddr("255.255.255.255"):
                    for dhcpIp in self.dhcpIPList:
                        self.dhcp_send_request(pkt, dhcpIp)
                else:
                    self.dhcp_send_request(pkt, pkt['dstip'])
            
        else: # pakety server->klient
            try:
                location = self.dhcpMac[pkt['dstmac']]
            except KeyError:
                print "WARNING: arp_routing > handle_dhcp > server->klient > KeyError:", pkt['dstmac'], self.dhcpMac
                return
            
            pkt = pkt.modify(switch=location[0])
            pkt = pkt.modify(outport=location[1])
            
            try:
                self.network.inject_packet(pkt)
            except Exception:
                print "WARNING: arp_routing > handle_dhcp > server->klient > inject_packet:", pkt, self.network
    
    def handle_packet(self, pkt):
        if pkt['ethtype'] == ARP_TYPE:
            self.handle_arp(pkt)
        else:
            self.handle_dhcp(pkt)
        
    
    def update_policy (self):
        self.newPolicy = if_(match(ethtype=ARP_TYPE), self.backet, drop)
        self.newPolicy = if_(match(protocol=17,srcport=68,dstport=67,ethtype=IP_TYPE), self.backet, self.newPolicy) #DHCP klient->server
        self.newPolicy = if_(match(protocol=17,srcport=67,dstport=68,ethtype=IP_TYPE), self.backet, self.newPolicy) #DHCP server->klient
        
        self.createPolicyForGroup(None)
        
        for group in self.groups:
            self.createPolicyForGroup(group)
        
        self.policy = self.newPolicy
        
        if self.debug:
            print "routing POLICY", self.policy
                    
def main():
    return firewall_groups()

