"""
    Nazev souboru: na_detector.py
    Autor: Jindrich Dudek (xdudek04)
    Email: xdudek04@stud.fit.vutbr.cz
    Datum posledni modifikace: 05/18/2018
    Verze Pythonu: 2.7 (2.7.12)
"""

import sys
import os
import argparse
import error_checker as EC
import filter as F
import atomic_attack_processor as AP
import stream_group_attack_processor as SG
import yaml
import xml.etree.cElementTree as ET
from collections import defaultdict
#from memory_profiler import profile


class NetworkAttacksDetector:
    """
    Detects network attacks in a captured communication according to attack
    descriptions stored as YAML files in a directory.

    The public methods are meant to be called in this order:
    parse_descriptions, check_syntax, convert_and_parse, execute.
    """

    def __init__(self, options):
        """
        Stores the options, creates helper objects and lists the description
        files.
        :param options: Object with attributes "directory" (folder with the
                        attack descriptions), "input" (file with captured
                        communication) and "structured" (flag passed to the
                        processors; it also switches on the XML report tags).
        """
        directory_path = options.directory
        self.__is_structured = options.structured
        self.__file_path = options.input
        self.__checker = EC.ErrorChecker()
        self.__atomic_processor = AP.AtomicAttackProcessor(self.__is_structured)
        self.__stream_group_processor = SG.StreamGroupAttackProcessor(
                                                           self.__is_structured)
        self.__filter = F.Filter()
        # (path, parsed description) tuples, filled in parse_descriptions and
        # reduced to the valid ones in check_syntax:
        self.__descriptions = []
        # Parsed packets grouped by the path of the description they belong to:
        self.__packets_by_attack = defaultdict(list)
        # Check if directory and file exists
        self.__checker.check_arguments(directory_path, self.__file_path)
        # Get content of provided directory
        self.__desc_list = self.__get_description_files(directory_path)

    def execute(self):
        """
        Main method that calls attack processing for every stored description.
        In structured mode the output of the processors is wrapped into a
        <report> element.
        """
        # Imported locally so that the class does not depend on a new
        # file-level import.
        from xml.sax.saxutils import escape

        if self.__is_structured:  # Print opening report tag
            print('<?xml version="1.0" encoding="UTF-8"?>')
            # The path is escaped, otherwise characters such as & or " would
            # produce a malformed attribute.
            print('<report file="%s">'
                  % escape(self.__file_path, {'"': '&quot;'}))

        for file_name, desc in self.__descriptions:
            scope = desc['scope']
            if scope == 'atomic':
                processor = self.__atomic_processor
            elif scope in ('stream', 'group'):
                # Both scopes are handled by the same processor.
                processor = self.__stream_group_processor
            else:
                continue
            processor.process(desc, self.__packets_by_attack[file_name])

        if self.__is_structured:  # Print closing report tag
            print('</report>')

    def convert_and_parse(self, tmp_filename):
        """
        Converts input file with network communication to PDML file format
        using tshark, classifies the converted packets and removes the
        temporary file. Exits the program with CONVERSION_ERR code when the
        conversion fails.
        :param tmp_filename: Name of temporary file with PDML content.
        """
        # The command is an argument list (no shell), so paths and field names
        # containing spaces, quotes or shell metacharacters are passed to
        # tshark unchanged.
        command = ['tshark', '-r', self.__file_path, '-T', 'pdml']
        for field in sorted('%s' % name
                            for name in self.__get_required_fields()):
            command.extend(['-e', field])
        display_filter = self.__create_display_filter()
        if display_filter:
            command.extend(['-Y', display_filter])

        try:
            if not self.__run_conversion(command, tmp_filename):
                self.__checker.report_error("Cannot convert input file to "
                                            "PDML file format.")
                sys.exit(self.__checker.error_codes.CONVERSION_ERR)
            self.__get_required_packets(tmp_filename)
        finally:
            # Remove the temporary file on success as well as on failure.
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)

    def __run_conversion(self, command, output_path):
        """
        Runs the conversion command with its standard output redirected to the
        given file.
        :param command: List with the program name and its arguments.
        :param output_path: Path of the file that receives the output.
        :return: True if the command finished with zero exit code, False
                 otherwise (also when it cannot be started or the output file
                 cannot be created).
        """
        import subprocess

        try:
            with open(output_path, 'wb') as pdml_file:
                return subprocess.call(command, stdout=pdml_file) == 0
        except (IOError, OSError):
            return False

    def __create_display_filter(self):
        """
        Method that creates display filter expression for tshark to include
        only required packets to converted PDML file. The expression is only
        a pre-selection, every packet is checked again in __classify_packet.
        :return: String with the display filter expression or empty string
                 when all packets have to be converted.
        """
        partial_filters = set()

        for _, desc in self.__descriptions:
            # If at least one description has no properties attribute
            # we need all packets:
            if 'properties' not in desc:
                return ''  # No display filter

            rules = []
            for prop in desc['properties']:
                if len(prop) == 2:  # Simple rule (field-name, valid)
                    if prop['valid']:
                        rules.append(prop['field-name'])
                    else:
                        rules.append('!%s' % prop['field-name'])
                elif len(prop) == 3:  # Complex rule (field-name, valid, value)
                    operator = '==' if prop['valid'] else '!='
                    rules.append('%s%s%s' % (prop['field-name'], operator,
                                             prop['value']))

            # A description without any usable rule would produce the invalid
            # expression "()", so no pre-selection is done in that case.
            if not rules:
                return ''
            # All rules of one description have to hold at once
            partial_filters.add('(%s)' % ' && '.join(rules))

        # A packet is needed when at least one description needs it; sorting
        # only makes the resulting expression deterministic.
        return ' || '.join(sorted(partial_filters))

    def __get_required_packets(self, file_name):
        """
        Method that reads PDML file, collects the text lines that represent
        exactly one packet and passes every such packet to __classify_packet.
        :param file_name: Name of converted PDML file.
        """
        packet_lines = []
        inside_packet = False
        with open(file_name, 'rb') as input_file:
            for line in input_file:
                # Lines are bytes (file is opened in binary mode); stripping
                # makes the comparison independent of the line ending.
                tag = line.strip()
                if tag == b'<packet>':
                    packet_lines = [line]
                    inside_packet = True
                elif tag == b'</packet>' and inside_packet:
                    packet_lines.append(line)
                    inside_packet = False
                    self.__classify_packet(b''.join(packet_lines))
                    packet_lines = []
                elif inside_packet:
                    packet_lines.append(line)

    def __classify_packet(self, packet_str):
        """
        Method that parses one packet and stores it for every attack
        description that requires it.
        :param packet_str: String that contains packet for parsing.
        """
        packet = ET.fromstring(packet_str)  # Parse packet
        for file_name, desc in self.__descriptions:
            # Without properties section every packet is required, otherwise
            # the filter decides.
            if ('properties' not in desc or
                    self.__filter.is_packet_required(packet,
                                                     desc['properties'])):
                self.__packets_by_attack[file_name].append(packet)

    def __get_description_files(self, directory_path):
        """
        Method that returns list of absolute paths to all entries of provided
        directory (expected to be files with attack descriptions).
        :param directory_path: Path to directory.
        :return: Sorted list with absolute paths.
        """
        absolute_path = os.path.abspath(directory_path)
        # Sorted, so descriptions are always processed in the same order.
        return [os.path.join(absolute_path, entry)
                for entry in sorted(os.listdir(absolute_path))]

    def parse_descriptions(self):
        """
        Method that parses every description file and stores tuple
        (file path, parsed content) for it. File that cannot be read or
        parsed is reported and skipped.
        """
        for desc_path in self.__desc_list:
            try:
                with open(desc_path, "r") as desc_file:
                    # NOTE(security): yaml.load can construct arbitrary Python
                    # objects from tagged input, so the descriptions have to
                    # come from a trusted source. Consider yaml.safe_load.
                    desc = yaml.load(desc_file)
            except (IOError, yaml.YAMLError):
                self.__checker.report_error('Cannot process YAML file: %s'
                                            % desc_path)
                continue

            self.__descriptions.append((desc_path, desc))

    def check_syntax(self):
        """
        Method that checks correct syntax of every attack description.
        Descriptions with an error are reported and dropped, so only correct
        descriptions are used by the following steps.
        """
        correct_descriptions = []
        error_msg = "This error was found in the \"%s\" attack description.\n"

        # Check syntax of every description
        for desc_tuple in self.__descriptions:
            file_name, desc = desc_tuple

            if not isinstance(desc, dict):
                self.__checker.report_error('Description of the attack in file '
                                            '%s should be dictionary '
                                            '(structure).' % file_name)
                continue
            # Check if attribute "name" is present in YAML description
            if 'name' not in desc or not isinstance(desc['name'], str):
                self.__checker.report_error('Attribute "name" is not present '
                                            'in the description located in '
                                            'file %s or its value is not a '
                                            'string.' % file_name)
                continue
            # Check if attribute "scope" is present in YAML description
            if 'scope' not in desc:
                self.__checker.report_error('Attribute "scope" in the '
                                            'description of attack "%s" '
                                            'was not found.' % desc['name'])
                continue

            # Check if all required attributes are present
            scope = desc['scope']
            if scope == 'atomic':
                is_valid = self.__checker.check_atomic_syntax(desc)
            elif scope == 'stream':
                is_valid = self.__checker.check_stream_syntax(desc)
            elif scope == 'group':
                is_valid = self.__checker.check_group_syntax(desc)
            else:
                self.__checker.report_error('Attribute "scope" in the '
                                            'description of attack "%s" has '
                                            'unsupported value.' % desc['name'])
                continue

            if not is_valid:
                # Tell the user which description the failed check belongs to.
                sys.stderr.write(error_msg % desc['name'])
                continue

            correct_descriptions.append(desc_tuple)
        # __descriptions will only contain correct descriptions
        self.__descriptions = correct_descriptions

    def __get_required_fields(self):
        """
        Method that gets names of fields from every attack description that
        are necessary for the attack detection.
        :return: Set of necessary fields (always contains tcp.stream and
                 udp.stream).
        """
        fields = {'tcp.stream', 'udp.stream'}  # Initialization of field set

        for _, desc in self.__descriptions:
            # Add fields from properties
            for prop in desc.get('properties', ()):
                fields.add(prop['field-name'])

            scope = desc['scope']
            if scope == 'atomic':
                self.__get_atomic_required_fields(desc, fields)
            elif scope == 'stream':
                self.__get_stream_required_fields(desc, fields)
            elif scope == 'group':
                self.__get_group_required_fields(desc, fields)

        return fields

    def __get_atomic_required_fields(self, desc, fields):
        """
        Method that gets required fields from attack description of type atomic.
        :param desc: Description with scope attribute set to atomic.
        :param fields: Set of required fields, updated in place.
        """
        for cond in desc['detection-conditions']['conditions']:
            cond_type = cond['condition-type']
            if cond_type in ('field-value', 'field-count'):
                fields.add(cond['field-name'])
            elif cond_type == 'packet-ratio':
                for packet in cond['packets']:
                    for prop in packet['properties']:
                        fields.add(prop['field-name'])
            elif cond_type == 'expression':
                # Only variables of type field-value refer to a field.
                for var in cond['variables']:
                    if var['type'] == 'field-value':
                        fields.add(var['field-name'])

    def __get_stream_required_fields(self, desc, fields):
        """
        Method that gets required fields from attack description of type stream.
        :param desc: Description with scope attribute set to stream.
        :param fields: Set of required fields, updated in place.
        """
        for spec in desc['packets-specification']['specification']:
            for prop in spec['packet-properties']:
                fields.add(prop['field-name'])

    def __get_group_required_fields(self, desc, fields):
        """
        Method that gets required fields from attack description of type group.
        :param desc: Description with scope attribute set to group.
        :param fields: Set of required fields, updated in place.
        """
        # Packets of a group description are specified like in a stream one
        self.__get_stream_required_fields(desc, fields)

        for grp in desc['group-by']:
            fields.add(grp['field-name'])

def get_arguments():
    """
    Declares the command line parameters of the script and parses the ones
    given by the user.
    :return: Parsed options with attributes "directory", "input" and
             "structured".
    """
    # Mandatory parameters: (short flag, long flag, help text)
    mandatory = (
        ("-d", "--directory",
         "Specifies directory with network attacks description."),
        ("-i", "--input",
         "Specifies input file with captured network communication"),
    )

    arg_parser = argparse.ArgumentParser()
    for short_flag, long_flag, description in mandatory:
        arg_parser.add_argument(short_flag, long_flag, required=True,
                                help=description)
    # Optional switch, False unless given on the command line
    arg_parser.add_argument("-s", "--structured", action='store_true',
                            required=False,
                            help="Enables structured report of attacks "
                                 "detection in XML format.")

    return arg_parser.parse_args()


if __name__ == "__main__":
    # Imported here because it is needed only when the file runs as a script.
    import tempfile

    options = get_arguments()
    det = NetworkAttacksDetector(options)
    det.parse_descriptions()  # Parse descriptions in provided folder
    det.check_syntax()  # Check syntax of parsed descriptions

    # A unique temporary file is used for the PDML output instead of a fixed
    # name in the current directory, so an existing file is never overwritten
    # and concurrent runs do not collide.
    tmp_fd, tmp_path = tempfile.mkstemp(prefix='na_detector_', suffix='.xml')
    os.close(tmp_fd)  # Only the path is needed, the file is reopened later
    try:
        det.convert_and_parse(tmp_path)
    finally:
        # Make sure the temporary file does not stay behind, whether the
        # conversion succeeded or the program is exiting with an error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    det.execute()
