# Class describing NID Graph
#
# Copyright (C) 2013 Radek Hranicky
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import copy

class NIDGraph(object):
    """ Graph of network identifiers (NIDs) and the connections between them.

    The graph consists of a common part (``nodes`` / ``connections``) and any
    number of LIID-specific parts kept per LIID. A connection is stored as the
    tuple ``(NID1, NID2, moduleName)`` and is treated as undirected: the
    tuple ``(NID2, NID1, moduleName)`` denotes the same connection.
    """

    # Identifier types stepped through, in order, by a level 1 lookup.
    # The first step starts from the base NID only; every following step
    # starts from all NIDs collected so far.
    _LEVEL1_STEPS = {
        "A": ("A",),
        "B": ("A",),
        "C": ("B", "A"),
        "D": ("C", "B", "A"),
    }

    def __init__(self):
        # Nodes of the common graph
        self.nodes = set()
        # NIDs that delNode() refuses to delete (from any graph)
        self.persistentNodes = set()
        # Connections (NID1, NID2, moduleName) of the common graph
        self.connections = set()
        # NIDs that are never reported as neighbors during NID lookup
        self.nonUniqueNIDs = set()
        # LIID -> set of nodes of that LIID-specific graph
        self.liid_specific_nodes = dict()
        # LIID -> set of connections of that LIID-specific graph
        self.liid_specific_connections = dict()

    def setNonUniqueNIDs(self, nuNIDs):
        """ Replaces the set of non-unique NIDs with the given iterable """
        self.nonUniqueNIDs = set(nuNIDs)

    def _getNodeSet(self, LIID):
        """ Returns the node set of the common graph (LIID == "") or of the
            given LIID-specific graph, or None if that graph has no nodes yet.
        """
        if LIID == "":
            return self.nodes
        return self.liid_specific_nodes.get(LIID)

    def _getConnectionSet(self, LIID):
        """ Returns the connection set of the common graph (LIID == "") or of
            the given LIID-specific graph, or None if it has no connections yet.
        """
        if LIID == "":
            return self.connections
        return self.liid_specific_connections.get(LIID)

    def nodeExists(self, NID, LIID=""):
        """ Returns True if a node with given NID exists in the common graph
            (LIID == "") or in the given LIID-specific graph.
        """
        nodeset = self._getNodeSet(LIID)
        return nodeset is not None and NID in nodeset

    def addNode(self, NID, LIID=""):
        """ Adds a new node into the common or LIID-specific graph
            (does nothing if the node already exists)
        """
        if LIID == "":
            self.nodes.add(NID)
        else:
            self.liid_specific_nodes.setdefault(LIID, set()).add(NID)

    def addPersistentNode(self, NID, LIID=""):
        """ Adds a node into the graph and marks its NID as persistent,
            so that delNode() will never delete it.
        """
        self.addNode(NID, LIID)
        self.persistentNodes.add(NID)

    def delNode(self, NID, LIID=""):
        """ Deletes a node with given NID together with all its connections
            (does nothing if the NID is marked as persistent)
        """
        if NID in self.persistentNodes:
            return  # Marked as persistent - cannot be deleted
        if LIID == "":
            # Common graph: remove the node and every connection touching it
            self.nodes.discard(NID)
            self.connections = {con for con in self.connections if NID not in con[0:2]}
            return
        # LIID-specific graph
        if LIID not in self.liid_specific_nodes:
            return
        self.liid_specific_nodes[LIID].discard(NID)
        if LIID in self.liid_specific_connections:
            self.liid_specific_connections[LIID] = {
                con for con in self.liid_specific_connections[LIID]
                if NID not in con[0:2]
            }

    def connectionExists(self, NID1, NID2, moduleName, LIID=""):
        """ Returns True if a connection between two nodes exists
            (in either direction) for the given module.
        """
        conset = self._getConnectionSet(LIID)
        if conset is None:
            return False
        return ((NID1, NID2, moduleName) in conset or
                (NID2, NID1, moduleName) in conset)

    def update_addConnection(self, NID1, NID2, moduleName, LIID=""):
        """ Adds a new connection between NID1 and NID2
            Adds NID1 if it does not exist
            Adds NID2 if it does not exist
            Returns True if changes were made in the common graph
            (always False for an LIID-specific graph)
        """
        if LIID != "":
            # Add a connection to LIID-specific graph
            self.addNode(NID1, LIID)
            self.addNode(NID2, LIID)
            self.addConnection(NID1, NID2, moduleName, LIID)
            return False
        # Add a connection to common graph
        changes = False
        for nid in (NID1, NID2):
            if not self.nodeExists(nid):
                self.addNode(nid)
                changes = True
        if self.addConnection(NID1, NID2, moduleName):
            changes = True
        return changes

    def update_delConnection(self, NID1, NID2, moduleName, LIID=""):
        """ Deletes a connection between NID1 and NID2
            Deletes NID1 if there is no other connection to it
            Deletes NID2 if there is no other connection to it
            Returns True if changes were made in the common graph
            (always False for an LIID-specific graph)
        """
        if LIID != "":
            # Delete a connection in LIID-specific graph
            self.delConnection(NID1, NID2, moduleName, LIID)
            for nid in (NID1, NID2):
                if self.isOrphan(nid, LIID):
                    self.delNode(nid, LIID)
            return False
        # Delete a connection in common graph
        changes = self.delConnection(NID1, NID2, moduleName)
        for nid in (NID1, NID2):
            if self.isOrphan(nid):
                self.delNode(nid)
                changes = True
        return changes

    def addConnection(self, NID1, NID2, moduleName, LIID=""):
        """ Adds a new connection between two nodes
            (Does nothing if the connection already exists)
            Returns True if changes were made in the common graph
            (always False for an LIID-specific graph)
        """
        if LIID == "":
            # Add to common graph
            if self.connectionExists(NID1, NID2, moduleName):
                return False
            self.connections.add((NID1, NID2, moduleName))
            return True
        # Add to LIID-specific graph
        conset = self.liid_specific_connections.setdefault(LIID, set())
        if not self.connectionExists(NID1, NID2, moduleName, LIID):
            conset.add((NID1, NID2, moduleName))
        return False

    def isOrphan(self, NID, LIID=""):
        """ Returns True if there is NO connection to this node
            (False if there is a connection to this node, and also False
            when the given LIID has no connection set at all)
        """
        conset = self._getConnectionSet(LIID)
        if conset is None:
            return False
        return not any(NID in con[0:2] for con in conset)

    def delConnection(self, NID1, NID2, moduleName, LIID=""):
        """ Deletes a connection between two nodes (in either direction)
            from the common graph (LIID == "") or the LIID-specific graph
            Returns True if changes were made in the common graph
            (always False for an LIID-specific graph)
        """
        conset = self._getConnectionSet(LIID)
        if conset is None:
            return False
        changes = False
        for con in ((NID1, NID2, moduleName), (NID2, NID1, moduleName)):
            if con in conset:
                conset.discard(con)
                changes = True
        # Only changes in the common graph are reported
        return changes if LIID == "" else False

    def debug_printGraph(self):
        """ Prints the content of NID graph """
        print("============= DEBUG: Printing NID Graph ===============")
        print("------ NODES ------")
        for node in self.nodes:
            print(node)
        print("------ CONNECTIONS ------")
        for con in self.connections:
            print(con)

    def deleteLIIDSpecificGraph(self, LIID):
        """ Removes all nodes and connections of the given LIID-specific graph
            (does nothing for parts that do not exist)
        """
        # Each dict is handled independently so that a missing node set
        # does not leave stale connections behind (and vice versa).
        self.liid_specific_nodes.pop(LIID, None)
        self.liid_specific_connections.pop(LIID, None)

    # =================================================================
    # Methods for NID lookup

    def getAllNeighbors(self, baseNID, nodeset, conset):
        """ Gets all directly connected NIDs to specific base NID
            (non-unique NIDs are never returned)
        """
        if baseNID not in nodeset:
            return set()
        return {node
                for con in conset if baseNID in con[0:2]
                for node in con[0:2]
                if node != baseNID and node not in self.nonUniqueNIDs}

    def getNeighborsOfType(self, baseNID, neighborType, nodeset, conset):
        """ Gets all directly connected NIDs of specific identifier type to
            given base NID (non-unique NIDs are never returned)
        """
        # Check membership in the supplied nodeset (which may include an
        # LIID-specific graph), not only in the common graph.
        if baseNID not in nodeset:
            return set()
        return {node
                for con in conset if baseNID in con[0:2]
                for node in con[0:2]
                if (node != baseNID and node.getIdentifierType() == neighborType and
                    node not in self.nonUniqueNIDs)}

    def _getNeighborsOfTypeForAll(self, nids, neighborType, nodeset, conset):
        """ Returns the union of getNeighborsOfType() over all NIDs in nids """
        found = set()
        for nid in nids:
            found.update(self.getNeighborsOfType(nid, neighborType, nodeset, conset))
        return found

    def getAccessableNids(self, baseNID, throughA, throughB, throughC, throughD, ignoreServerB, nodeset, conset):
        """ Gets all NIDs which are accessible from given base NID

        baseNID base NID (always part of the result)
        throughA include and go through identifiers of type "A"
        throughB include and go through identifiers of type "B"
        throughC include and go through identifiers of type "C"
        throughD include and go through identifiers of type "D"
        ignoreServerB ignore identifiers of type "B" which represent a server IP of a TCP connection
        Identifiers of any other type are included and traversed.
        """
        allowedTypes = {"A": throughA, "B": throughB, "C": throughC, "D": throughD}

        def accepted(nid):
            if nid in self.nonUniqueNIDs:
                return False
            if not allowedTypes.get(nid.getIdentifierType(), True):
                return False
            return not (ignoreServerB and self.isServer(nid, nodeset, conset))

        nids = {baseNID}
        frontier = {baseNID}
        # Breadth-first search: only newly found NIDs need to be expanded,
        # since acceptance of a neighbor does not depend on the path to it.
        while frontier:
            found = set()
            for nid in frontier:
                for neighbor in self.getAllNeighbors(nid, nodeset, conset):
                    if neighbor not in nids and neighbor not in found and accepted(neighbor):
                        found.add(neighbor)
            nids.update(found)
            frontier = found
        return nids

    def isServer(self, nid, nodeset, conset):
        """ Returns True if given NID (of type "B") is the server IP of a
            directly connected TCP connection NID (of type "A")
        """
        if nid.getIdentifierType() != "B":
            return False
        for con in self.getNeighborsOfType(nid, "A", nodeset, conset):
            if con.getType() == "TCP" and str(nid) == str(con.getServerIP()):
                return True
        return False

    def nidLookup(self, firstBaseNID, iceptLevel, LIID=""):
        """ Gets all NIDs for given base NID (from intercept table) and given
            level of interception (1, 2 or 3)

        Every graph NID of the same type that is contained in firstBaseNID
        (e.g. an IP address inside an address range) is used as a base NID.
        Returns a set of NIDs; if no such base NID exists in the graph,
        returns [firstBaseNID] for a static NIDCC and [] otherwise.
        """
        # Combine the common graph with the LIID-specific graph (if any).
        # Shallow copies are enough: the lookup never mutates the graph.
        nodeset = set(self.nodes)
        conset = set(self.connections)
        if LIID != "":
            nodeset.update(self.liid_specific_nodes.get(LIID, ()))
            conset.update(self.liid_specific_connections.get(LIID, ()))

        # Handle the situation when baseNID is IP address range, etc.
        baseType = firstBaseNID.getType()
        basenids = {nid for nid in nodeset
                    if nid.getType() == baseType and nid in firstBaseNID}

        if not basenids:
            # Base NID does not exist in the graph
            if firstBaseNID.isNIDCC():
                # Interception of a static NIDCC - it need not be in the graph
                return [firstBaseNID]
            # Nothing will be taken into list
            return []

        lookup = {
            1: self.getNidsLevel1,
            2: self.getNidsLevel2,
            3: self.getNidsLevel3,
        }.get(iceptLevel)
        nids = set()
        for baseNID in basenids:
            if lookup is not None:
                nids.update(lookup(baseNID, nodeset, conset))
            # The base NID itself is always part of the result
            nids.add(baseNID)
        return nids

    def getNidsLevel1(self, baseNID, nodeset, conset):
        """ Gets all NIDs for given base NID and level 1 of interception

        * type "A" or "B" => directly connected neighbors of type "A"
        * type "C" => directly connected neighbors of type "B" and the
          neighbors of type "A" connected to them
        * type "D" => directly connected neighbors of type "C", neighbors of
          type "B" connected to them, and neighbors of type "A" connected
          to any of those
        * any other type => empty set
        """
        steps = self._LEVEL1_STEPS.get(baseNID.getIdentifierType())
        if steps is None:
            return set()
        nids = self.getNeighborsOfType(baseNID, steps[0], nodeset, conset)
        for neighborType in steps[1:]:
            nids.update(self._getNeighborsOfTypeForAll(nids, neighborType, nodeset, conset))
        return nids

    def _getNidsThroughTypes(self, baseNID, throughD, nodeset, conset):
        """ Shared implementation of interception levels 2 and 3

        * type "A" => all nodes accessible through types A, B, C (and D if
          throughD), except B nodes acting as server of a TCP connection
        * other types => all nodes accessible through types B, C (and D if
          throughD) plus all type "A" nodes directly connected to them
        """
        if baseNID.getIdentifierType() == "A":
            return self.getAccessableNids(baseNID, True, True, True, throughD, True, nodeset, conset)
        nids = self.getAccessableNids(baseNID, False, True, True, throughD, False, nodeset, conset)
        nids.update(self._getNeighborsOfTypeForAll(nids, "A", nodeset, conset))
        return nids

    def getNidsLevel2(self, baseNID, nodeset, conset):
        """ Gets all NIDs (as a set) for given base NID and level 2 of interception """
        return self._getNidsThroughTypes(baseNID, False, nodeset, conset)

    def getNidsLevel3(self, baseNID, nodeset, conset):
        """ Gets all NIDs (as a list) for given base NID and level 3 of interception """
        return list(self._getNidsThroughTypes(baseNID, True, nodeset, conset))

    def createNIDTable(self):
        """ Creates NID table from the common NID graph:
            one "moduleName<TAB>NID1<TAB>NID2" line per connection
        """
        return "".join(
            "\t".join((str(con[2]), str(con[0]), str(con[1]))) + "\n"
            for con in self.connections
        )

    def getAllCommonNodes(self):
        """ Returns the node set of the common graph """
        return self.nodes

    def getAllLIIDSpecificNodes(self, LIID):
        """ Returns the node set of the given LIID-specific graph
            (raises KeyError if the LIID has no nodes)
        """
        return self.liid_specific_nodes[LIID]

    def findConnections_str(self, NIDs):
        """ Returns connections (common and all LIID-specific) whose both
            endpoints are distinct members of NIDs, as a list of unordered
            pairs of strings (each pair reported once regardless of module)
        """
        # Build a new set - updating self.connections directly would merge
        # all LIID-specific connections into the common graph.
        cons = set(self.connections)
        for liidCons in self.liid_specific_connections.values():
            cons.update(liidCons)
        found_cons = {frozenset((nid_a, nid_b))
                      for nid_a, nid_b, _moduleName in cons
                      if nid_a in NIDs and nid_b in NIDs and nid_a != nid_b}
        # Convert to string representation
        cons_str = list()
        for found_con in found_cons:
            fc = list(found_con)
            cons_str.append((str(fc[0]), str(fc[1])))
        return cons_str
            
