"""
    Nazev souboru: atomic_attack_processor.py
    Autor: Jindrich Dudek (xdudek04)
    Email: xdudek04@stud.fit.vutbr.cz
    Datum posledni modifikace: 05/18/2018
    Verze Pythonu: 2.7 (2.7.12)
"""

import re
from collections import defaultdict

import filter as F
import error_checker as EC


class AtomicAttackProcessor:
    """
    Interpreter of atomic attack descriptions.

    Evaluates the detection conditions of a parsed YAML attack description
    against a list of parsed packets and prints the verdict, either as plain
    sentences or (when is_structured is set) as XML <attack> elements.
    """

    def __init__(self, is_structured):
        """
        :param is_structured: True to print results as XML elements, False to
                              print human-readable sentences.
        """
        self.__filter = F.Filter()
        self.__checker = EC.ErrorChecker()
        self.__is_structured = is_structured

    def process(self, desc, packets):
        """
        Main method for description interpretation.
        :param desc: Parsed YAML description of the attack.
        :param packets: Parsed filtered packets.
        """
        if not packets:
            self.__print_not_detected(desc['name'], self.__is_structured)
            return

        det_conditions = desc['detection-conditions']
        if det_conditions['type'] == 'and':
            self.__and_conditions_processing(packets, desc)
        else:
            # Every type other than "and" is interpreted as "or".
            self.__or_conditions_processing(packets, desc)

    def __or_conditions_processing(self, packets, desc):
        """
        Evaluate the conditions of a description whose "type" is "or".

        Conditions are evaluated one by one against all packets. The first
        condition whose detections reach its own thresholds is reported and
        processing stops; if none does, the attack is reported as not
        detected.
        :param packets: List with filtered packets.
        :param desc: Parsed YAML description.
        """
        name = desc['name']
        if self.__is_structured:
            check = self.__check_detection_structured
        else:
            check = self.__check_detection

        for cond in desc['detection-conditions']['conditions']:
            # Start every condition from a clean state; otherwise detections
            # of a previous condition would be re-checked against the
            # thresholds of a condition that did not match at all.
            pos_det = []
            # True when the report should list the matching packets:
            packet_report = False
            cond_type = cond['condition-type']

            if cond_type == 'expression':
                if self.__all_variables_abstract(cond['variables']):
                    if self.__evaluate_abstract_expression(packets, cond):
                        pos_det = packets  # Just to indicate detection
                else:
                    # Abstract and specific variables combined, or only
                    # specific ones:
                    pos_det = self.__evaluate_combined_expression(packets,
                                                                  cond)
                    packet_report = True
            elif cond_type == 'field-value':
                value_type = cond.get('value-type')
                if value_type == 'specific':
                    pos_det = self.__get_specific_field_value_packets(packets,
                                                                     cond)
                    packet_report = True
                elif value_type == 'abstract':
                    if self.__evaluate_abstract_field_value(packets, cond):
                        pos_det = packets  # Just to indicate detection
                else:
                    # No value-type: simple rule with only "field-name".
                    pos_det = self.__get_packets_by_field_name(packets, cond)
                    packet_report = True
            elif cond_type == 'field-count':
                pos_det = self.__get_packets_by_field_count(packets, cond)
                packet_report = True
            elif cond_type == 'packet-ratio':
                if self.__evaluate_packet_ratio_cond(packets, cond):
                    pos_det = packets  # Just to indicate detection

            if pos_det and check(pos_det, name, cond.get('threshold-error'),
                                 cond.get('threshold-warning'),
                                 packet_report):
                return  # Attack was detected and reported

        self.__print_not_detected(name, self.__is_structured)

    def __and_conditions_processing(self, packets, desc):
        """
        Evaluate the conditions of a description whose "type" is "and".

        Packet-level conditions successively narrow the set of matching
        packets; abstract conditions (judged on all packets) either keep the
        current set or clear it. The final set is compared with the
        thresholds stored at the top level of the description.
        :param packets: List with filtered packets.
        :param desc: Parsed YAML description.
        """
        conditions = desc['detection-conditions']['conditions']
        # Becomes True once a condition selects concrete packets, so the
        # report lists the matching packet numbers.
        packet_report = False
        # Packets satisfying every condition evaluated so far.
        pos_det = packets

        for cond in conditions:
            cond_type = cond['condition-type']
            if cond_type == 'expression':
                if self.__all_variables_abstract(cond['variables']):
                    if not self.__evaluate_abstract_expression(packets, cond):
                        pos_det = []  # Clear the list to indicate no detection
                else:
                    detections = self.__evaluate_combined_expression(packets,
                                                                     cond)
                    # Keep only packets matching both the earlier conditions
                    # and this expression.
                    pos_det = self.__intersect(pos_det, detections)
                    packet_report = True
            elif cond_type == 'field-value':
                value_type = cond.get('value-type')
                if value_type == 'specific':
                    pos_det = self.__get_specific_field_value_packets(pos_det,
                                                                     cond)
                    packet_report = True
                elif value_type == 'abstract':
                    if not self.__evaluate_abstract_field_value(packets, cond):
                        pos_det = []  # Attack was not detected
                else:
                    # No value-type: simple rule with only "field-name".
                    pos_det = self.__get_packets_by_field_name(pos_det, cond)
                    packet_report = True
            elif cond_type == 'field-count':
                pos_det = self.__get_packets_by_field_count(pos_det, cond)
                packet_report = True
            elif cond_type == 'packet-ratio':
                if not self.__evaluate_packet_ratio_cond(packets, cond):
                    pos_det = []  # Condition is False

            if not pos_det:  # No need to continue if there are no packets
                break

        threshold_err = desc.get('threshold-error')
        threshold_warn = desc.get('threshold-warning')
        if self.__is_structured:
            self.__report_attack_structured(pos_det, desc['name'],
                                            threshold_err, threshold_warn,
                                            packet_report)
        else:
            self.__report_attack(pos_det, desc['name'], threshold_err,
                                 threshold_warn, packet_report)

    def __evaluate_packet_ratio_cond(self, packets, cond):
        """
        Method that evaluates condition of type "packet-ratio".

        The ratio is |first group| / |second group|, where both groups are
        produced by the packet filter. An empty second group gives an
        infinite ratio, unless the first group is empty too: then there is
        no evidence at all and the ratio is 0.
        :param packets: List of input packets.
        :param cond: Parsed YAML description of condition.
        :return: True if condition is satisfied, False otherwise.
        """
        fst_grp = self.__filter.process(packets,
                                        cond['packets'][0]['properties'])
        snd_grp = self.__filter.process(packets,
                                        cond['packets'][1]['properties'])

        if len(snd_grp) > 0:
            ratio = len(fst_grp) / float(len(snd_grp))
        elif len(fst_grp) > 0:
            ratio = float('inf')
        else:
            ratio = 0.0

        return ratio >= cond['ratio']

    def __get_packets_by_field_count(self, packets, cond):
        """
        Method that evaluates condition of type "field-count" and returns
        all packets containing exactly cond['count'] fields named
        cond['field-name'].
        :param packets: List of input packets.
        :param cond: Parsed YAML description of condition.
        :return: List of packets that satisfy condition.
        """
        xpath = self.__field_xpath(cond['field-name'])
        count = cond['count']
        return [packet for packet in packets
                if len(packet.findall(xpath)) == count]

    def __evaluate_abstract_field_value(self, packets, cond):
        """
        Method that evaluates condition of type "field-value" with attribute
        value-type set to "abstract".
        :param packets: List of packets in input file.
        :param cond: Parsed YAML description of condition.
        :return: True if the abstract value reaches cond['count'].
        """
        return self.__get_abstract_value(packets, cond) >= cond['count']

    def __get_specific_field_value_packets(self, packets, cond):
        """
        Method that returns packets that satisfy condition "field-value"
        with "specific" attribute "value-type", i.e. packets whose field
        cond['field-name'] has the value cond['value'].
        :param packets: List of input packets.
        :param cond: Parsed YAML condition.
        :return: List of packets with positive detection.
        """
        field_name = cond['field-name']
        expected = cond['value']
        pos_det = []
        for packet in packets:
            value = self.__get_field_value(packet, field_name)
            # Packets lacking the field never match. NOTE(review): the field
            # value is a string read from XML, so a YAML value parsed as a
            # number will not compare equal - confirm descriptions quote it.
            if value is not None and value == expected:
                pos_det.append(packet)
        return pos_det

    def __get_packets_by_field_name(self, packets, cond):
        """
        Method that returns all packets which contain field with specified
        name.
        :param packets: List of input packets.
        :param cond: Parsed YAML condition.
        :return: List of packets with positive detection.
        """
        xpath = self.__field_xpath(cond['field-name'])
        return [packet for packet in packets
                if packet.find(xpath) is not None]

    def __evaluate_abstract_expression(self, packets, cond):
        """
        Method for evaluation of expression where only abstract variables
        are used.
        :param packets: List of packets.
        :param cond: Parsed YAML description of condition.
        :return: True if expression is evaluated as True, False otherwise
                 (including when it cannot be evaluated; an error is
                 reported through the error checker in that case).
        """
        values, _ = self.__split_variables(packets, cond['variables'])
        expression = self.__substitute_variables(cond['expression'], values)
        try:
            # NOTE(review): eval() runs the expression taken from the attack
            # description with full builtins - descriptions must be trusted.
            return bool(eval(expression))
        except Exception:
            self.__checker.report_error('Expression in condition of type '
                                        '"expression" can not be evaluated.')
            return False

    def __get_abstract_value(self, packets, desc):
        """
        Method that gets value of condition type "field-value" with attribute
        value-type set to "abstract".
        :param packets: List of packets to get result from.
        :param desc: Parsed YAML description of condition (or variable).
        :return: For desc['value'] == 'different', the number of distinct
                 values of field desc['field-name']; otherwise ('same'), the
                 number of occurrences of its most frequent value. Packets
                 without the field are ignored; 0 if no packet has it.
        """
        field_name = desc['field-name']
        values = [self.__get_field_value(packet, field_name)
                  for packet in packets]
        values = [value for value in values if value is not None]

        if desc['value'] == 'different':
            return len(set(values))

        # desc['value'] == 'same'
        occurrences = defaultdict(int)
        for value in values:
            occurrences[value] += 1
        # max() of an empty sequence raises, so handle "field never present".
        return max(occurrences.values()) if occurrences else 0

    def __get_field_value(self, packet, field_name):
        """
        Method that returns value of specified field.
        :param packet: Packet where specified field should be.
        :param field_name: Name of field.
        :return: Attribute "value" of the first matching field, or None if the
                 field (or the attribute) is not present.
        """
        field = packet.find(self.__field_xpath(field_name))
        if field is None:
            return None
        return field.get('value')

    def __field_xpath(self, field_name):
        """
        Build the XPath selecting descendants whose "name" attribute equals
        field_name.
        :param field_name: Name of field.
        :return: XPath string.
        """
        return './/*[@name="%s"]' % field_name

    def __split_variables(self, packets, variables):
        """
        Compute textual values of abstract variables and collect the others.
        :param packets: List of packets the abstract values are computed from.
        :param variables: List with variables description.
        :return: Tuple (values, others): values maps the name of every
                 abstract variable ("filtered-frame-count" or abstract
                 "field-value") to its value as a string, others lists the
                 remaining variable descriptions in their original order.
        """
        values = {}
        others = []
        for var in variables:
            if var['type'] == 'filtered-frame-count':
                values[var['name']] = str(len(packets))
            elif var['type'] == 'field-value' \
                    and var['value-type'] == 'abstract':
                values[var['name']] = str(self.__get_abstract_value(packets,
                                                                    var))
            else:
                others.append(var)
        return values, others

    def __substitute_variables(self, expression, values):
        """
        Replace variable names in expression with their values.

        The replacement is done in a single pass and longer names are tried
        first, so a name that is a prefix of another one cannot corrupt it and
        inserted values are never scanned for further variable names.
        :param expression: Expression text from the description.
        :param values: Dictionary mapping variable name to replacement text.
        :return: Expression with all variable names replaced.
        """
        names = sorted((name for name in values if name), key=len,
                       reverse=True)
        if not names:
            return expression
        pattern = '|'.join(re.escape(name) for name in names)
        return re.sub(pattern, lambda match: values[match.group(0)],
                      expression)

    def __to_literal(self, value):
        """
        Convert a field value read from a packet into Python source text.

        Integer-like values are normalised through int() (so "010" becomes
        10 instead of an octal or invalid literal); anything else becomes a
        string literal via repr(), so quotes inside packet data cannot break
        out of the literal and inject code into the evaluated expression.
        :param value: Field value (string).
        :return: Source text of the corresponding literal.
        """
        if self.__is_int(value):
            return str(int(value))
        return repr(value)

    def __evaluate_combined_expression(self, packets, cond):
        """
        Method for evaluation of expression where abstract variables are
        combined with specific variables, or only specific variables are used.
        Abstract variables are computed once over all packets; specific ones
        are read from each packet in turn.
        :param packets: List of packets.
        :param cond: Parsed YAML with condition.
        :return: Packets for which the expression is True (in input order);
                 empty list if the expression cannot be evaluated.
        """
        abstract_values, spec_vars = self.__split_variables(
            packets, cond['variables'])
        pos_det = []

        for packet in packets:
            values = dict(abstract_values)
            for var in spec_vars:
                value = self.__get_field_value(packet, var['field-name'])
                if value is None:
                    # A variable is missing in this packet; there is no
                    # reason to evaluate the expression.
                    break
                values[var['name']] = self.__to_literal(value)
            else:
                # Reached only when every specific variable was found.
                expression = self.__substitute_variables(cond['expression'],
                                                         values)
                try:
                    # NOTE(review): eval() of the description's expression;
                    # packet data is inserted only as escaped literals.
                    if eval(expression):
                        pos_det.append(packet)
                except Exception:
                    self.__checker.report_error(
                        'Expression in condition of type '
                        '"expression" can not be evaluated.')
                    return []

        return pos_det

    def __all_variables_abstract(self, variables):
        """
        Method that indicates whether all variables in list are abstract.
        Values of abstract variables are gained from more than one packet.
        :param variables: List with variables description.
        :return: True if no variable is a "field-value" with value-type
                 "specific", False otherwise.
        """
        return not any(var['type'] == 'field-value'
                       and var['value-type'] == 'specific'
                       for var in variables)

    def __is_int(self, s):
        """
        Check if string can be converted to integer.
        :param s: String.
        :return: True if string can be converted. False otherwise.
        """
        try:
            int(s)
            return True
        except (ValueError, TypeError):
            return False

    def __classify_detection(self, count, threshold_err, threshold_warn):
        """
        Compare the number of positive detections with the thresholds.
        :param count: Number of packets with positive detection.
        :param threshold_err: Threshold for level ERROR (or None).
        :param threshold_warn: Threshold for level WARNING (or None).
        :return: 'error', 'warning', or None when no configured threshold is
                 reached (including when no threshold is configured at all).
        """
        if threshold_err is not None and count >= threshold_err:
            return 'error'
        if threshold_warn is not None and count >= threshold_warn:
            return 'warning'
        return None

    def __print_not_detected(self, attack_name, structured):
        """
        Print the negative verdict for an attack.
        :param attack_name: Name of the attack.
        :param structured: True for XML output, False for plain text.
        """
        if structured:
            print('\t<attack name="%s" type="atomic" detection="negative" />'
                  % attack_name)
        else:
            print('Attack "%s" was not detected.' % attack_name)

    def __report_attack(self, pos_det, attack_name, threshold_err,
                        threshold_warn, packet_report):
        """
        Method that reports if attack was detected or not.
        Always prints a verdict; without reached thresholds it is "not
        detected".
        :param pos_det: Packets with positive detection.
        :param attack_name: Name of the attack.
        :param threshold_err: Threshold for report ERROR.
        :param threshold_warn: Threshold for report WARNING
        :param packet_report: True if report should write packets with positive
                              detection, False otherwise.
        """
        detected = len(pos_det) > 0 and self.__check_detection(
            pos_det, attack_name, threshold_err, threshold_warn,
            packet_report)
        if not detected:
            self.__print_not_detected(attack_name, False)

    def __report_attack_structured(self, pos_det, attack_name, threshold_err,
                                   threshold_warn, packet_report):
        """
        Method that reports if attack was detected or not in XML format.
        Always prints an <attack> element; without reached thresholds it is
        negative.
        :param pos_det: Packets with positive detection.
        :param attack_name: Name of the attack.
        :param threshold_err: Threshold for report ERROR.
        :param threshold_warn: Threshold for report WARNING
        :param packet_report: True if report should write packets with positive
                              detection, False otherwise.
        """
        detected = len(pos_det) > 0 and self.__check_detection_structured(
            pos_det, attack_name, threshold_err, threshold_warn,
            packet_report)
        if not detected:
            self.__print_not_detected(attack_name, True)

    def __print_detections_structured(self, packet_numbers):
        """
        Method that reports packets that were marked as positive detection in
        XML format.
        :param packet_numbers: Numbers of packets that were marked as positive
                               detection.
        """
        for number in packet_numbers:
            print('\t\t<packet num="%s"/>' % number)

    def __check_detection(self, pos_det, attack_name, threshold_err,
                          threshold_warn, packet_report):
        """
        Method that compares count of positive detections with thresholds and
        reports attack in case it was detected. Without packet report the
        detection is always reported with level WARNING.
        :param pos_det: Packets with positive detection.
        :param attack_name: Name of the attack.
        :param threshold_err: Threshold for report ERROR.
        :param threshold_warn: Threshold for report WARNING
        :param packet_report: True if report should write packets with positive
                              detection, False otherwise.
        :return True if attack was detected (and printed), False otherwise.
        """
        error = 'Attack "%s" was detected. Level: ERROR.' % attack_name
        warn = 'Attack "%s" was detected. Level: WARNING.' % attack_name

        if not packet_report:
            print(warn)
            return True

        level = self.__classify_detection(len(pos_det), threshold_err,
                                          threshold_warn)
        if level is None:
            return False

        packet_numbers = self.__get_packet_numbers(pos_det)
        print(error if level == 'error' else warn)
        print('Found %d positive detection(s) in packet(s) number: %s'
              % (len(pos_det), ', '.join(packet_numbers)))
        return True

    def __check_detection_structured(self, pos_det, attack_name, threshold_err,
                                     threshold_warn, packet_report):
        """
        Method that compares count of positive detections with thresholds and
        reports attack in case it was detected in XML format. Without packet
        report the detection is always reported with level "warning".
        :param pos_det: Packets with positive detection.
        :param attack_name: Name of the attack.
        :param threshold_err: Threshold for report ERROR.
        :param threshold_warn: Threshold for report WARNING
        :param packet_report: True if report should write packets with positive
                              detection, False otherwise.
        :return True if attack was detected (and printed), False otherwise.
        """
        if not packet_report:
            print('\t<attack name="%s" type="atomic" '
                  'detection="positive" level="warning" />' % attack_name)
            return True

        level = self.__classify_detection(len(pos_det), threshold_err,
                                          threshold_warn)
        if level is None:
            return False

        # level is exactly 'error' or 'warning', the XML attribute value.
        print('\t<attack name="%s" type="atomic" '
              'detection="positive" level="%s">' % (attack_name, level))
        self.__print_detections_structured(self.__get_packet_numbers(pos_det))
        print('\t</attack>')
        return True

    def __get_packet_numbers(self, pos_det):
        """
        Method that returns list of packet's numbers.
        :param pos_det: List of packets.
        :return: List of "show" attributes of each packet's "num" field; '?'
                 for a packet without such field.
        """
        xpath = self.__field_xpath('num')
        numbers = []
        for packet in pos_det:
            elem = packet.find(xpath)
            numbers.append('?' if elem is None else elem.get('show'))
        return numbers

    def __intersect(self, a, b):
        """
        Method that returns intersection of two lists.
        :param a: First list.
        :param b: Second list.
        :return: Items of a that are also in b, in the order of a (so the
                 reported packet numbers stay in input order).
        """
        members = set(b)
        return [item for item in a if item in members]






