#!/usr/bin/env python3
#
# IRI-CORE of the light version of the LI system
#
# Copyright (C) 2013 Radek Hranický
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
import socket
import sys
import threading
import time
import string
import copy
from threading import *
from collections import deque

from modules.sockets.li_socket import LISocket
from modules.sockets.li import acceptTCPLIConnection
from modules.sockets.li_socket_manager import LISocketManager
from modules.sockets.li import acceptConnection, ptptAsClient, acceptTCPLIConnection, \
    acceptTCPLIConnectionSimple
from modules.shared.ini1aintercept import INI1AIntercept
from modules.shared.cid import CID
from modules.shared.nid import NID
from modules.shared.nid import CreateNID
from modules.shared.nid import CreateNIDType
from modules.shared.nid import NIDIP
from modules.tools.config import parseServerTcpIfcConfig
from modules.tools.config import parseClientTcpIfcConfig
from modules.tools.config import parseIRICoreNIDConfig
from modules.tools.time import periodToStructTime
from modules.shared.time_priority_queue import TimePriorityQueue
from modules.iriiif.dhcp_status import *
from subprocess import call
import subprocess as sub
import modules.tools.log as log

from modules.iriiif.intercept_table import *
from modules.iriiif.cin_table import *
from modules.iriiif.nid_graph import *

# Path of the file that writeIRIPTable() overwrites with the known topology
IRIP_KNOWN_TOPOLOGY_FILE = "irip.content"

def writeIRIPTable(probe, content):
    """ Overwrite the IRI Probe known-topology file

    probe Not used by this function
    content Text that replaces the whole content of IRIP_KNOWN_TOPOLOGY_FILE
    """
    # The event is logged before the file is actually touched
    log.info("IRI Probe known network topology file overwriten")
    with open(IRIP_KNOWN_TOPOLOGY_FILE, "w") as topology_file:
        topology_file.write(content)

def iricore_init():
    """ (Re)create the module-level state of the IRI-IIF core

    Kept as a separate function for testing purposes. Besides rebinding the
    globals below it also truncates the IRI Probe topology file.
    """
    global QUEUE_SIZE, interceptQueue, interceptTable, cinTable, nidGraph
    global run, logger

    # 0 is meant as "infinite"
    QUEUE_SIZE = 0
    # Interceptions waiting for their activation
    interceptQueue = TimePriorityQueue(None)
    # Intercept table, CIN table and NID graph
    interceptTable = InterceptTable()
    cinTable = CINTable()
    nidGraph = NIDGraph()
    # Start with an empty known-topology file
    writeIRIPTable(None, "")
    # Flag read by queueThread(); cleared by main() on shutdown
    run = True
    logger = log.createFileLogger("IRICore_logger", "iri-core.log")

# Build the module-level state at import time; the functions below read these
# globals (interceptTable, cinTable, nidGraph, interceptQueue, run, logger).
iricore_init()

def main(argv):
    """ Entry point of the IRI-IIF core

    argv Command line; argv[0] is handed to the socket manager, argv[1:]
         to the logging setup

    Reads iri-core.ini, builds the socket manager, starts the intercept queue
    thread, runs the socket manager main loop and tears everything down once
    the loop returns.
    """
    global logger, sm, run

    # Setup logging
    log.setupLogging(argv[1:])

    # Read configuration
    ini1a = parseClientTcpIfcConfig("ini1a", "iri-core.ini")
    ini2 = parseClientTcpIfcConfig("ini2", "iri-core.ini")
    iricore = parseServerTcpIfcConfig("iricore", "iri-core.ini")
    nidGraph.setNonUniqueNIDs(parseIRICoreNIDConfig("iri-core","iri-core.ini"))

    # Initialize socket manager (the four interface groups are positional)
    iricoreInterface = ("iricore",) + iricore
    ini2Interface = ("ini2" ,) + ini2
    ini1aInterface = ("ini1a",) + ini1a
    sm = LISocketManager(
        (),
        ("iricorewq",),
        (iricoreInterface, ini2Interface),
        (ini1aInterface,),
    )

    # NOTE: a TCP server for the SDM modules (socketserver.TCPServer) used to
    # be sketched here in commented-out form; it is not started.

    # Prepare and run the intercept queue thread
    intqThread = threading.Thread(
        target=queueThread,
        args=(interceptQueue, "iricorewq"),
        daemon=True,
    )
    intqThread.start()

    # Run socket manager
    sm.tryInterfaces(argv[0])
    logEvent("IRI-IIF Core started")

    # Main loop
    sm.mainLoop(globals())

    # Close sockets
    sm.closeSockets()
    logEvent("IRI-IIF Core finished")

    # Stop the intercept queue thread: clear the flag, then push a dummy item
    # so that the blocking pop() in queueThread() returns.
    # NOTE(review): an absolute timestamp is pushed here, whereas newIntercept()
    # pushes a relative delay - confirm which one TimePriorityQueue expects.
    run = False
    interceptQueue.push(time.mktime(time.localtime()), None)
    intqThread.join()

def queueThread(queue, socketName):
    """ Forward waiting interceptions once the queue releases them

    queue Queue of waiting interceptions (pop() hands out the next one)
    socketName Name of the socket each popped interception is sent to

    Ends as soon as an item is popped while the global run flag is cleared.
    """
    while True:
        pending = queue.pop()
        if run:
            # A fresh client connection is opened for every interception
            ptptAsClient(socketName).send(pending)
        else:
            return
        
###########################################################################
# Socket handlers
      
def processRequestIRICOREWQ(s, sm):
    """ A waiting interception is about to start

    s Server socket
    sm Socket manager

    Accepts the connection, receives one interception from it (blocking)
    and activates that interception. Always returns True.
    """
    connection = acceptConnection(s)
    activateIntercept(sm, connection.blockingReceive())
    return True

def brokenSocketINI1A(s, sm):
    """ Called when the INI1A socket connection fails

    s Broken socket
    sm Socket manager (not used)

    Logs the peer as an error and returns False (the IRI-IIF should not
    continue).
    """
    peer = str(s.getpeername())
    logEvent("INI1A: broken connection %s" % peer, "error")
    return False

def brokenSocketINI2(s, sm):
    """ Called when an INI2 socket connection fails

    s Broken socket
    sm Socket manager (not used)

    Logs the peer as an error and returns True (the IRI-IIF should continue).
    """
    peer = str(s.getpeername())
    logEvent("INI2: broken connection %s" % peer, "error")
    return True

def brokenSocketIRICORE(s, sm):
    """ Called when the IRICORE socket connection fails

    s Broken socket
    sm Socket manager (not used)

    Logs the peer as an error and returns False (the IRI-IIF should not
    continue).
    """
    peer = str(s.getpeername())
    logEvent("IRI-IIF: broken connection to IRI-probe %s" % peer, "error")
    return False

def brokenSocketIRIINTERNAL(s, sm):
    """ Called when an internal IRI-probe socket connection fails

    s Broken socket
    sm Socket manager (not used)

    Logs the peer as an error and returns True (the IRI-IIF should continue).
    """
    peer = str(s.getpeername())
    logEvent("IRI-IIF: broken internal connection to IRI-probe %s" % peer, "error")
    return True

def processMessageINI1A(msg, sm, s):
    """ Handler for messages received through INI1a

    msg Received message; msg[0] selects the handler and msg[2:] are
        passed to it as arguments
    sm Socket manager
    s Socket that has received the message (not used)

    Returns whatever the selected handler returns, i.e. whether the
    IRI-IIF should continue. An unknown msg[0] raises KeyError.
    """
    # The message is acknowledged before it is processed
    sm.send("ini1a", ("ack", msg))

    handlers = {"new_intercept": newIntercept, "delete_intercept": deleteIntercept}
    handler = handlers[msg[0]]
    return handler(sm, *msg[2:])


def processMessageINI2(msg, sm, s):
    """ Handler for messages received through INI2

    msg Received message (ignored)
    sm Socket manager (ignored)
    s Socket that has received the message (ignored)

    Nothing is processed here; the handler only reports that the IRI-IIF
    should continue.
    """
    keepRunning = True
    return keepRunning


def processRequestINI2(s, sm):
    """ Accept a new connection on the INI2 interface

    s Server socket
    sm Socket manager

    The accepted connection is registered under the name "ini2" and the
    peer is logged as a new consumer.

    Returns if the IRI-IIF should continue (always True)
    """

    conn = acceptTCPLIConnection(s, sm, "ini2")
    # str() is required: a (host, port) tuple given directly to "%" would be
    # unpacked as two format arguments and raise TypeError. The brokenSocket*
    # handlers format the peer name the same way.
    logEvent("IRI-IIF: new consumer: %s" % str(conn.getpeername()), "info")
    return True

def processRequestIRICORE(s, sm):
    """ Accept a new IRI-probe connection on the IRICORE interface

    s Server socket
    sm Socket manager

    The accepted connection is registered under the name "iriinternal" and
    the peer is logged as a new IRI-probe.

    Returns if the IRI-IIF should continue (always True)
    """

    conn = acceptTCPLIConnection(s, sm, "iriinternal")
    # str() is required: a (host, port) tuple given directly to "%" would be
    # unpacked as two format arguments and raise TypeError. The brokenSocket*
    # handlers format the peer name the same way.
    logEvent("IRI-IIF: new IRI-probe: %s" % str(conn.getpeername()), "info")
    return True

def processMessageIRIINTERNAL(msg, sm, s):
    """ Handler for messages received from IRI probes

    msg Received message: a pair (IRI probe name, payload). The payload is
        one of:
          7 fields - moduleID, timestamp, iriType, description and three NID
                     lists (important, metadata, persistent)
          5 fields - moduleID, timestamp, iriType, description, NID list
          4 fields - old format without the description string
        The 7 and 5 field variants may carry one extra trailing field (8 or
        6 fields in total): a negative answer string that is sent back
        through s when no interception matched the message.
        Every NID list item is indexed as (type, value).
    sm Socket manager
    s Socket that has received the message

    Updates the NID graph and, for every NIDCC of the message that is related
    to an active interception, sends an IRI report to the MF.

    Returns if the IRI-IIF should continue (always True)
    """
    if len(msg) != 2:
        # Wrap msg in a 1-tuple: formatting a tuple directly with "%" would
        # raise TypeError for every tuple whose length is not 1.
        log.error("Invalid message received: %s" % (msg,))
        return True

    iri_probe_name, msg = msg
    log.info("Received from IRI Probe ", iri_probe_name, ": ", msg)

    negativeAnswer = False
    negativeAnswer_str = ""
    isIntercepted = False

    msg_nids = set()             # Important NIDs (INI2 + graph)
    msg_metadata_nids = set()    # Support metadata (INI2 only)
    msg_persistent_nids = set()  # Will not be sent to MF (graph only)

    # Check if the message contains a negative answer request.
    # If so, save the negative answer string, set the negative answer flag
    # and strip the message.
    if len(msg) == 8: # 3 lists + negative answer request
        negativeAnswer_str = msg[7]
        msg = msg[0:7]
        negativeAnswer = True
    if len(msg) == 6: # single list + negative answer request
        negativeAnswer_str = msg[5]
        msg = msg[0:5]
        negativeAnswer = True

    # Process (stripped) message
    if len(msg) == 7: # 3 lists
        moduleID, timestamp, iriType, description, nidArray1, nidArray2, nidArray3 = msg
        # Extract nids from second (metadata nids) and third (persistent nids) list
        for nid in nidArray2:
            msg_metadata_nids.add(CreateNIDType(nid[1],nid[0]))
        for nid in nidArray3:
            msg_persistent_nids.add(CreateNIDType(nid[1],nid[0]))
        # Modify the message format: metadata NIDs are appended to the
        # description, the other two lists are merged into a single one
        if len(msg_metadata_nids) > 0:
            description = str(description) + " | " + (str(msg_metadata_nids).replace("'", "")).replace("\"","")
        nidArray = nidArray1 + nidArray3
        msg = (moduleID, timestamp, iriType, description, nidArray)
    elif len(msg) == 5: # single list
        moduleID, _,iriType, _,nidArray = msg
    else:
        log.info("old msg format (there is no description string)")
        # |modul id | ts | iri-type | [ ("MAC","mac"), ...] |
        moduleID,timeStamp,iriType,nidArray = msg
        msg = (moduleID, timeStamp, iriType, "", nidArray)

    # Create a set of NIDs we got from the message
    tmp_nids = set()
    for nid in nidArray:
        tmp_nids.add(CreateNIDType(nid[1],nid[0]))

    LIIDs = interceptTable.getAllLIIDs()

    # Newly added NIDs (IPv4/6 we got from a TCP connection)
    # => Wont be connected with each other in NID graph
    addedNids = dict()

    # For each TCP or TCP3 create IPv4/6 NIDs if needed; for e-mail messages
    # add the sender, the message ID and all recipients
    for nid in tmp_nids:
        msg_nids.add(nid)
        if nid.getType() == "TCP":
            if not nid.getClientIP() in msg_metadata_nids:
                msg_nids.add(nid.getClientIP())
                addedNids[nid.getClientIP()] = nid
            if not nid.getServerIP() in msg_metadata_nids:
                msg_nids.add(nid.getServerIP())
                addedNids[nid.getServerIP()] = nid
        elif nid.getType() == "TCP3":
            if not nid.getIP() in msg_metadata_nids:
                msg_nids.add(nid.getIP())
                addedNids[nid.getIP()] = nid
        elif nid.getType() == "E-mail message":
            msg_nids.add(nid.getSndEmail())
            addedNids[nid.getSndEmail()] = nid
            msg_nids.add(nid.getMsgID())
            addedNids[nid.getMsgID()] = nid
            rcpts = nid.getRcptEmails()
            for rcpt in rcpts:
                msg_nids.add(rcpt)
                addedNids[rcpt] = nid


    ################################################################
    # 1) Update NID graph (and inform MF about changes)            #
    ################################################################
    # Add persistent NIDs
    if iriType == "BEGIN" or iriType == "CONTINUE":
        for nid in msg_persistent_nids:
            nidGraph.addPersistentNode(nid)

    # Update common (not LIID-specific) graph
    if len(msg_nids) >= 2 and iriType != "REPORT":
        if updateNIDGraph(msg_nids, addedNids, moduleID, iriType):
            nidTable = nidGraph.createNIDTable()
            writeIRIPTable(None, nidTable)

    # Update LIID-specific graphs
    for nid in msg_nids:
        for iceptLIID in LIIDs:
            # Iterate over a copy: the graph is modified inside the loop
            liid_specific_nids = copy.deepcopy(nidGraph.getAllLIIDSpecificNodes(iceptLIID))
            for liid_specific_nid in liid_specific_nids:
                if (liid_specific_nid in nid or nid in liid_specific_nid) and liid_specific_nid != nid:
                    if iriType == "BEGIN" or iriType == "CONTINUE":
                        nidGraph.update_addConnection(liid_specific_nid, nid, "range", LIID=iceptLIID)
                    elif iriType == "END":
                        nidGraph.update_delConnection(liid_specific_nid, nid, "range", LIID=iceptLIID)

    # Filter out all NIDCCs from the message
    msg_nidCCs = {nidcc for nidcc in msg_nids if nidcc.isNIDCC()}

    if len(nidArray) == 0:
        return True

    ####################################################################
    # 2) Search through all NIDCCs:                                    #
    # Find out if NIDCC is related with an existing interception.      #
    # If so, resend IRI report for each relationship (LIID and NIDCC). #
    ####################################################################
    if not msg_nidCCs:
        return True # No NIDcc -> authentication message -> ignore
    # We have at least one NIDcc - not an authentication message
    # Search through all interceptions

    for iceptLIID in LIIDs:
        interception = interceptTable.getInterception(iceptLIID)
        iceptNID = CreateNID(interceptTable.getNID(iceptLIID))
        iceptLevel = interceptTable.getLevel(iceptLIID)
        # Get all related NIDCCs from NID graph
        icept_nids = nidGraph.nidLookup(iceptNID, iceptLevel, LIID=iceptLIID)
        icept_nidCCs = {nid for nid in icept_nids if nid.isNIDCC()}
        cons = nidGraph.findConnections_str(icept_nids)

        # Search through all NIDCCs from the message
        for msg_nidcc in msg_nidCCs:
            # The report type is adjusted separately for every (LIID, NIDCC)
            # pair. Rewriting iriType itself would leak the adjustment into
            # the following pairs, e.g. an END downgraded to CONTINUE for one
            # pair would suppress the END report and the CIN clean-up of
            # every later pair.
            reportType = iriType
            # Check if NIDcc is related to the intercept
            generate_iri = any(msg_nidcc in icept_nidCC for icept_nidCC in icept_nidCCs)
            if generate_iri:
                isIntercepted = True
                if reportType == "BEGIN":                              # ** BEGIN **
                    if cinTable.recordExists(msg_nidcc,iceptLIID):     # Record in CIN table exists:
                        reportType = "CONTINUE"                        #    BEGIN -> CONTINUE
                    else:                                              # Record in CIN table does not exist:
                        updateCIN(msg_nidcc, interception)             #    update CIN
                elif reportType == "CONTINUE":                         # ** CONTINUE **
                    if not cinTable.recordExists(msg_nidcc,iceptLIID): # Record in CIN table does not exist:
                        reportType = "BEGIN"                           #    CONTINUE -> BEGIN
                        updateCIN(msg_nidcc, interception)             #    update CIN
                elif reportType == "END":                              # ** END **
                    # The NIDCC is still in the common graph or in the
                    # graph of the given LIID
                    reportType = "CONTINUE"                            #    END -> CONTINUE
                # Generate the report
                cin = cinTable.getCIN(msg_nidcc, iceptLIID)
                sendIRIMessage(sm, reportType, interception, msg_nidcc, cin, msg, cons)
            # An END for a NIDCC that is no longer related to the
            # interception but still has a CIN record closes that record
            cin = cinTable.getCIN(msg_nidcc, iceptLIID)
            if reportType == "END" and cin is not None:
                sendIRIMessage(sm, reportType, interception, msg_nidcc, cin, msg, cons)
                cinTable.delete(msg_nidcc, iceptLIID)

    # Nobody intercepts this traffic -> send the requested negative answer
    if negativeAnswer and not isIntercepted:
        s.send(negativeAnswer_str)

    return True

# end of socket handlers    
########################################################################
#

def updateNIDGraph(nids, addedNids, moduleID, iriType):
    """ Update the common NID graph with the NIDs of one message

    nids NIDs of the message
    addedNids Maps every derived NID (e.g. an IP taken from a TCP connection)
              to the NID it was derived from; a derived NID is connected
              only with its origin
    moduleID Passed to the graph together with every connection
    iriType "BEGIN"/"CONTINUE" add connections, "END" removes them, any
            other value leaves the graph untouched

    Returns True if changes were made
    """
    if iriType in ("BEGIN", "CONTINUE"):
        # Adds a new connection (if it does not exist)
        applyChange = nidGraph.update_addConnection
    elif iriType == "END":
        applyChange = nidGraph.update_delConnection
    else:
        applyChange = None

    changed = False
    for pair in eachWithEach(set(nids)):
        first, second = pair
        # Skip pairs where a derived NID would meet anything but its origin
        firstIsForeign = first in addedNids and addedNids[first] != second
        if firstIsForeign or (second in addedNids and addedNids[second] != first):
            continue
        if applyChange is not None and applyChange(first, second, moduleID):
            changed = True
    return changed

def logEvent(eventText, severity = "info"):
    """ Write the given text to the module logger

    eventText - Text of the event
    severity - Level of the event; name of the logger method that is called
    """
    writer = getattr(logger, severity)
    writer(eventText)

def newIntercept(sm, interception):
    """ Register a new interception

    sm Socket manager
    interception Interception to register; its start time decides whether
                 it is activated right away or queued

    Returns if the IRI-IIF should continue (always True)
    """

    now = time.mktime(time.localtime())
    startTime = time.mktime(interception.getInterceptionStart())
    if startTime >= now:
        # Not due yet: let the intercept queue wait for the start time.
        # NOTE(review): a relative delay is pushed here while main() pushes an
        # absolute timestamp - confirm which one TimePriorityQueue expects.
        interceptQueue.push(startTime - now, interception)
        return True
    # The start time has already passed
    activateIntercept(sm, interception)
    return True

def deleteIntercept(sm, liid):
    """ Drop the interception identified by liid

    sm Socket manager (not used)
    liid - LIID of the interception

    Returns if the IRI-IIF should continue (always True)
    """

    # LIID-specific graph first, then the intercept table and the CIN records
    nidGraph.deleteLIIDSpecificGraph(liid)
    interceptTable.remIntercept(liid)
    cinTable.deleteWithLIID(liid)

    notice = "IRI-IIF: removing interception with LIID: %s" % liid
    logEvent(notice, "info")
    return True

def activateIntercept(sm, interception):
    """ Activate a new interception

    sm Socket manager
    interception Interception to activate

    Nothing is changed when an interception with the same LIID already
    exists. Always returns True.
    """
    iceptLIID = interception.getLIID()
    if interceptTable.liidExists(iceptLIID):
        logEvent("IRI-IIF: Interception with LIID: %s already exists!" % interception.getLIID(), "info")
        return True
    targetNID = CreateNID(interception.getNID())

    # 1) Add the interception to the intercept table
    interceptTable.addIntercept(interception)
    logEvent("IRI-IIF: adding new interception: %s" % interception, "info")

    # 2) Add the intercepted NID to the LIID-specific NID graph
    nidGraph.addNode(targetNID, LIID=iceptLIID)

    # 3) Connect it (inside the LIID-specific graph) with every different
    #    common node that contains it or is contained in it
    for commonNID in nidGraph.getAllCommonNodes():
        overlaps = targetNID in commonNID or commonNID in targetNID
        if overlaps and targetNID != commonNID:
            nidGraph.update_addConnection(targetNID, commonNID, "range", LIID=iceptLIID)

    # 4) Find all NIDCCs belonging to the interception (and its level)
    relatedNIDs = nidGraph.nidLookup(targetNID, interception.getLevel(), LIID=iceptLIID)
    relatedNIDCCs = {candidate for candidate in relatedNIDs if candidate.isNIDCC()}

    # 5) Generate and send a BEGIN message for each NIDCC
    for nidcc in relatedNIDCCs:
        # Update CIN table
        updateCIN(nidcc, interception)
        # Generate the message
        beginMsg = (
            "in progress",
            time.time(),
            "BEGIN",
            "Intercept activated",
            [(targetNID.getType(), targetNID)],
        )
        cin = cinTable.getCIN(nidcc, interception.getLIID())
        # Important: Every NID MUST be converted to string before sending to MF!
        connections = nidGraph.findConnections_str(relatedNIDs)
        sendIRIMessage(sm, "BEGIN", interception, nidcc, cin, beginMsg, connections)
    return True

def sendIRIMessage(sm, iriType, interception, nidcc, cin, msg, cons):
    """ Build an IRI report and send it to the MF through INI2

    sm Socket manager
    iriType, interception, nidcc, cin, msg, cons Passed unchanged to
        createIRIReport()
    """
    iriReport = createIRIReport(iriType, interception, nidcc, cin, msg, cons)
    logPrefix = "SENDING IRI " + iriType + ": "
    log.info(logPrefix, iriReport)
    envelope = ("iri_report", "mf", iriReport)
    sm.send("ini2", envelope)

def createIRIReport(iriType, interception, nidcc, cin, msg, nids):
    """ Build an IRI report tuple

    The report is (iriType, LIID, CID, NIDCC as string, msg, nids as tuple).
    The CID is a deep copy of the interception's CID with its CIN set to cin,
    so the interception itself is left untouched.
    """
    reportCID = copy.deepcopy(interception.getCID())
    reportCID.setCIN(cin)
    liid = interception.getLIID()
    nidccText = str(nidcc)
    return (iriType, liid, reportCID, nidccText, msg, tuple(nids))
    
def updateCIN(nidcc, interception):
    """ Assign the next CIN to the given NIDCC

    Takes the CIN stored for the interception in the intercept table,
    increases it by one, stores the new value in the CIN table for
    (nidcc, LIID) and writes it back to the intercept table.
    """
    currentCIN = interceptTable.getCIN(interception.getLIID())
    followingCIN = currentCIN + 1
    cinTable.update(nidcc, interception.getLIID(), followingCIN)
    interceptTable.setCIN(interception.getLIID(), followingCIN)
    
def eachWithEach(nidSet):
    """ Helper function building all unordered couples of different NIDs

    nidSet Collection of NIDs

    Returns a set of frozensets, one for every couple of unequal items.
    """
    return {
        frozenset([first, second])
        for first in nidSet
        for second in nidSet
        if first != second
    }

if __name__ == "__main__":
    # Script entry point: anything main() lets escape is reported through
    # the project's unhandled-exception logger
    try:
        main(sys.argv)
    except Exception as error:
        log.unhandledException("IRI-IIF", error)

