"""
    Nazev souboru: stream_group_attack_processor.py
    Autor: Jindrich Dudek (xdudek04)
    Email: xdudek04@stud.fit.vutbr.cz
    Datum posledni modifikace: 05/18/2018
    Verze Pythonu: 2.7 (2.7.12)
"""

import filter as F
import error_checker as EC
from collections import defaultdict
import copy as CP


class StreamGroupAttackProcessor:
    """
    Detects attacks of scope "stream" and "group" described in YAML.

    Packets are divided either into TCP/UDP streams (by the "tcp.stream" /
    "udp.stream" fields) or into groups keyed by the value of the first
    "group-by" field found in a packet. Every stream/group is matched against
    the "packets-specification" of the description, the number of matching
    streams/groups is compared with "threshold-warning" / "threshold-error"
    and the verdict is printed either as plain text or, when
    ``is_structured`` is set, as XML elements.

    Packets are accessed through an ElementTree-like API (``find``, ``get``,
    ``set``); presumably they are PDML packet elements -- TODO confirm.
    """

    def __init__(self, is_structured):
        """
        :param is_structured: True to print results as XML elements, False to
                              print them as plain text.
        """
        self.__filter = F.Filter()
        self.__checker = EC.ErrorChecker()
        self.__is_structured = is_structured

    def process(self, desc, packets):
        """
        Main method for description interpretation.

        Divides the packets into streams or groups according to
        ``desc['scope']`` ("stream"; anything else is handled as "group"),
        searches every stream/group for the attack and prints the report.
        :param desc: Parsed YAML description of the attack.
        :param packets: Filtered packets required for attack detection.
        """
        # Without any packets the attack cannot be present at all.
        if len(packets) == 0:
            if self.__is_structured:
                print('\t<attack name="%s" type="%s" detection="negative" />'
                      % (self.__xml_attr(desc['name']),
                         self.__xml_attr(desc['scope'])))
            else:
                print('Attack "%s" was not detected.' % desc['name'])
            return

        threshold_warn = desc.get('threshold-warning')
        threshold_err = desc.get('threshold-error')

        if desc['scope'] == 'stream':
            tcp_streams, udp_streams = self.__packets_to_streams(packets)
            pos_det_tcp = self.__scan_packet_groups(tcp_streams, desc)
            pos_det_udp = self.__scan_packet_groups(udp_streams, desc)
            # Pair every detected stream index with its transport protocol.
            pos_det = [('tcp', str(i)) for i in pos_det_tcp]
            pos_det += [('udp', str(i)) for i in pos_det_udp]

            if self.__is_structured:
                self.__report_stream_attack_structured(
                    pos_det, threshold_warn, threshold_err, desc['name'])
            else:
                self.__report_stream_attack(
                    pos_det, threshold_warn, threshold_err, desc['name'])
        else:  # Group
            groups = self.__packets_to_groups(packets, desc)
            pos_det = self.__scan_packet_groups(groups, desc)
            if self.__is_structured:
                self.__report_group_attack_structured(
                    pos_det, threshold_warn, threshold_err, desc['name'],
                    groups)
            else:
                self.__report_group_attack(
                    pos_det, threshold_warn, threshold_err, desc['name'],
                    groups)

    def __detection_level(self, count, threshold_warn, threshold_err):
        """
        Compares count of positive detections with the thresholds.

        The error threshold takes precedence over the warning threshold; a
        threshold that is not set (None) never triggers.
        :param count: Number of positive detections.
        :param threshold_warn: Value of attribute "threshold-warning" or None.
        :param threshold_err: Value of attribute "threshold-error" or None.
        :return: 'error', 'warning', or None when the attack is not reported.
        """
        if threshold_err is not None and count >= threshold_err:
            return 'error'
        if threshold_warn is not None and count >= threshold_warn:
            return 'warning'
        return None

    def __xml_attr(self, value):
        """
        Converts a value to text that is safe inside a double-quoted XML
        attribute (escapes &, <, > and ").
        :param value: Value to convert (formatted with '%s').
        :return: Escaped text.
        """
        from xml.sax.saxutils import escape
        # '%s' instead of str() keeps unicode values working on Python 2.
        return escape('%s' % (value,), {'"': '&quot;'})

    def __report_stream_attack(self, pos_det, threshold_warn, threshold_err,
                               attack_name):
        """
        Method that compares count of positive detections with thresholds and
        reports attack "stream" in case the conditions are satisfied.
        :param pos_det: List of (stream type, stream index) tuples of streams
                        that were marked as positive detection.
        :param threshold_warn: Value of attribute "threshold-warning".
        :param threshold_err: Value of attribute "threshold-error".
        :param attack_name: String with attack name.
        """
        level = self.__detection_level(len(pos_det), threshold_warn,
                                       threshold_err)
        if level is None:
            print('Attack "%s" was not detected.' % attack_name)
            return

        streams = ', '.join('%s: %s' % (kind, num) for kind, num in pos_det)
        print('Attack "%s" was detected. Level: %s.'
              % (attack_name, level.upper()))
        print('Found %d positive detection(s) in stream(s) number: %s'
              % (len(pos_det), streams))

    def __report_stream_attack_structured(self, pos_det, threshold_warn,
                                          threshold_err, attack_name):
        """
        Method that compares count of positive detections with thresholds and
        reports attack "stream" in case the conditions are satisfied in XML.
        :param pos_det: List of (stream type, stream index) tuples of streams
                        that were marked as positive detection.
        :param threshold_warn: Value of attribute "threshold-warning".
        :param threshold_err: Value of attribute "threshold-error".
        :param attack_name: String with attack name.
        """
        name = self.__xml_attr(attack_name)
        level = self.__detection_level(len(pos_det), threshold_warn,
                                       threshold_err)
        if level is None:
            print('\t<attack name="%s" type="stream" detection="negative" />'
                  % name)
            return

        print('\t<attack name="%s" type="stream" detection="positive" '
              'level="%s">' % (name, level))
        self.__print_streams_structured(pos_det)
        print('\t</attack>')

    def __print_streams_structured(self, streams):
        """
        Prints streams marked as positive detection in XML format.
        :param streams: (stream type, stream index) tuples of streams that
                        were marked as positive detection.
        """
        for kind, num in streams:
            print('\t\t<stream type="%s" num="%s" />'
                  % (self.__xml_attr(kind), self.__xml_attr(num)))

    def __report_group_attack(self, pos_det, threshold_warn, threshold_err,
                              attack_name, groups):
        """
        Method that compares count of positive detections with thresholds and
        reports attack "group" in case the conditions are satisfied.
        :param pos_det: List of group keys with positive detections.
        :param threshold_warn: Value of attribute "threshold-warning".
        :param threshold_err: Value of attribute "threshold-error".
        :param attack_name: String with name of the attack.
        :param groups: Dictionary with packets divided to groups.
        """
        level = self.__detection_level(len(pos_det), threshold_warn,
                                       threshold_err)
        if level is None:
            print('Attack "%s" was not detected.' % attack_name)
            return

        # Groups are numbered by their order in pos_det, starting at 1.
        report_groups = []
        for order, key in enumerate(pos_det, 1):
            nums = self.__get_packets_numbers(groups[key])
            report_groups.append('group %d: (%s)' % (order, ', '.join(nums)))

        print('Attack "%s" was detected. Level: %s.'
              % (attack_name, level.upper()))
        print('Found %d positive detection(s) in group(s) consisting of '
              'packets: %s' % (len(pos_det), ', '.join(report_groups)))

    def __report_group_attack_structured(self, pos_det, threshold_warn,
                                         threshold_err, attack_name, groups):
        """
        Method that compares count of positive detections with thresholds and
        reports attack "group" in case the conditions are satisfied in XML.
        :param pos_det: List of group keys with positive detections.
        :param threshold_warn: Value of attribute "threshold-warning".
        :param threshold_err: Value of attribute "threshold-error".
        :param attack_name: String with name of the attack.
        :param groups: Dictionary with packets divided to groups.
        """
        name = self.__xml_attr(attack_name)
        level = self.__detection_level(len(pos_det), threshold_warn,
                                       threshold_err)
        if level is None:
            print('\t<attack name="%s" type="group" detection="negative" />'
                  % name)
            return

        print('\t<attack name="%s" type="group" detection="positive" '
              'level="%s">' % (name, level))
        self.__print_groups_structured(groups, pos_det)
        print('\t</attack>')

    def __print_groups_structured(self, groups, pos_det):
        """
        Prints groups marked as positive detection in structured XML format.
        :param groups: Dictionary with packets divided to groups.
        :param pos_det: List of group keys with positive detection.
        """
        for key in pos_det:
            # Group keys come from packet field values, so escape them.
            print('\t\t<group id="%s">' % self.__xml_attr(key))
            for num in self.__get_packets_numbers(groups[key]):
                print('\t\t\t<packet num="%s"/>' % self.__xml_attr(num))
            print('\t\t</group>')

    def __packets_to_streams(self, packets):
        """
        Method that divides packets to TCP/UDP streams.

        A packet with a "tcp.stream" field goes to the TCP streams only;
        packets with neither field are dropped.
        :param packets: List of input packets.
        :return: Tuple (tcp_streams, udp_streams) of dictionaries mapping the
                 integer stream index to the list of its packets.
        """
        tcp_streams = defaultdict(list)
        udp_streams = defaultdict(list)

        for packet in packets:
            index = self.__get_field_value(packet, 'tcp.stream')
            if index is not None:
                tcp_streams[int(index)].append(packet)
                continue

            index = self.__get_field_value(packet, 'udp.stream')
            if index is not None:
                udp_streams[int(index)].append(packet)

        return tcp_streams, udp_streams

    def __get_field_value(self, packet, field_name):
        """
        Method that returns value of specified field.
        :param packet: Packet where specified field should be present.
        :param field_name: Name of the field (its "name" attribute).
        :return: The field's "value" attribute, or None if the field or the
                 attribute is not present.
        """
        field = packet.find('.//*[@name="%s"]' % field_name)
        if field is None:
            return None
        return field.get('value')

    def __scan_packet_groups(self, groups, desc):
        """
        Method that tries to find an attack in groups of packets.

        Side effect: every packet gets an "index" attribute with its position
        in its group (see __tag_packets).
        :param groups: Dictionary with packets divided to groups.
        :param desc: Parsed YAML description of an attack.
        :return: List of group keys with positive detection.
        """
        packets_spec = desc['packets-specification']
        specs = packets_spec['specification']
        follow = packets_spec['follow']
        specified_only = packets_spec['specified-only']
        pos_det = []

        # items() instead of iteritems() works on both Python 2 and 3.
        for key, group in groups.items():
            remaining = list(group)
            # Record each packet's position in the group (stream).
            self.__tag_packets(remaining)
            divided_packets = []
            for spec in specs:
                matched = self.__filter.process(remaining,
                                                spec['packet-properties'])
                divided_packets.append(matched)
                # A packet is assigned to the first specification it matches.
                remaining = self.__remove_list_elems(remaining, matched)

            if self.__evaluate_attack(divided_packets, specs, follow,
                                      specified_only, group):
                pos_det.append(key)

        return pos_det

    def __evaluate_attack(self, divided_packets, specs, follow, specified_only,
                          group):
        """
        Evaluates whether the attack is present in one group (stream) whose
        packets were already divided according to the packet specifications.
        :param divided_packets: List of lists; the i-th list holds packets
        matching the i-th specification.
        :param specs: Specifications of packets from YAML description.
        :param follow: Indicates whether packets specified in the description
        have to follow each other for successful detection of the attack.
        :param specified_only: Indicates whether the packets specified in the
        description should be the only packets in entire packet group (stream).
        :param group: List of packets where attack should be detected.
        :return: True if attack is detected in group (stream), False otherwise.
        """
        # Scaling below must not modify the shared description.
        specs_edit = CP.deepcopy(specs)

        # If at least one packet specification has attribute "count", the
        # attack may be repeated in the group, so scale the packet ratio.
        if not self.__all_ranged(specs_edit):
            if not self.__edit_packets_ratio(divided_packets, specs_edit):
                return False

        for packets, spec in zip(divided_packets, specs_edit):
            if not self.__check_packets_count(packets, spec):
                return False

        if specified_only:
            # Every packet of the group must match some specification.
            if len(self.__flatten_list(divided_packets)) != len(group):
                return False

        # NOTE: the order check uses the unscaled specs, so a single in-order
        # occurrence of the described sequence is sufficient.
        if follow and not self.__check_packets_order(divided_packets[0],
                                                     specs, group):
            return False

        return True

    def __collect_run(self, group, index, spec):
        """
        Collects consecutive packets of the group matching a specification.
        :param group: List of packets in the same group (stream).
        :param index: Position in the group where the run starts.
        :param spec: Parsed YAML specification of one packet.
        :return: Tuple (matched packets, position right after the run).
        """
        run = []
        while index < len(group) and self.__filter.process(
                [group[index]], spec['packet-properties']):
            run.append(group[index])
            index += 1
        return run, index

    def __check_packets_order(self, start_packets, specs, group):
        """
        Checks if packets in the attack specification are in the same order as
        in the group (stream).

        Starting at each candidate packet, a run of consecutive packets
        matching the first specification is collected, followed by a run for
        every next specification. The check succeeds for the first candidate
        whose runs all have acceptable counts.
        :param start_packets: List of packets from the group (stream) which
        correspond with the description of the first packet in the section
        "packets-specification". The list is not modified.
        :param specs: Parsed YAML specification of packets in stream (group).
        :param group: List of packets in the same group (stream); packets are
        expected to carry the "index" attribute set by __tag_packets.
        :return: True if packets follow in the same order as stated in the
        description, False otherwise.
        """
        # Work on a copy so the caller's list stays untouched.
        candidates = list(start_packets)
        while candidates:
            # Position of the first candidate in the group (stream).
            index = candidates[0].get('index')
            first_run, index = self.__collect_run(group, index, specs[0])
            if not self.__check_packets_count(first_run, specs[0]):
                candidates.pop(0)
                continue
            # Always drop the current candidate (even if the run is empty) so
            # the loop is guaranteed to make progress.
            candidates = self.__remove_list_elems(candidates[1:], first_run)

            found = True
            for j in range(1, len(specs)):
                run, index = self.__collect_run(group, index, specs[j])
                found = self.__check_packets_count(run, specs[j])
                if not found:
                    # The last run may contain more packets than specified.
                    if j == len(specs) - 1:
                        found = self.__check_last_sequence(run, specs[j])
                    break

            if found:
                return True
        return False

    def __check_last_sequence(self, sequence, spec):
        """
        Checks that the run matching the last packet description contains at
        least the required number of packets (extra packets are allowed).
        :param sequence: List of packets that correspond with the last packet
        description.
        :param spec: Specification of last packet.
        :return: True if the run has at least "count" packets (or at least
        "min-count" when "count" is absent), False otherwise.
        """
        count = spec.get('count')
        minimum = count if count is not None else spec['min-count']
        return len(sequence) >= minimum

    def __edit_packets_ratio(self, divided_packets, specs):
        """
        Scales the packet counts in the specifications (in place) to check if
        the same attack is performed more than once in the same group
        (stream).

        The coefficient is derived from the first "count" specification with
        a non-empty packet list; "count", "min-count" and numeric "max-count"
        are multiplied by it.
        :param divided_packets: Groups of packets that correspond with
        specification in YAML description.
        :param specs: Parsed YAML specifications of packets (modified).
        :return: True if ratio can be edited, False otherwise.
        """
        coefficient = 1
        for packets, spec in zip(divided_packets, specs):
            # An empty packet list would give coefficient 0.
            if 'count' not in spec or len(packets) == 0:
                continue

            found = len(packets)
            if found == spec['count']:
                return True  # No need to change packet ratio

            # Packets found where none are allowed, or not a whole multiple:
            # the ratio cannot be edited.
            if spec['count'] == 0 or found % spec['count'] != 0:
                return False

            coefficient = found // spec['count']
            break

        for spec in specs:
            if 'count' in spec:
                spec['count'] *= coefficient
            else:
                spec['min-count'] *= coefficient
                if spec['max-count'] != '*':
                    spec['max-count'] *= coefficient

        return True

    def __all_ranged(self, specs):
        """
        Check if every specification has attributes 'min/max-count' instead of
        attribute 'count'.
        :param specs: Specifications of packets in YAML.
        :return: True if no specification has attribute 'count'.
        """
        return not any('count' in s for s in specs)

    def __check_packets_count(self, packets, spec):
        """
        Method that checks if count of packets corresponds with count or range
        stated in packet description.
        :param packets: List of filtered packets.
        :param spec: Parsed YAML description of packet; either "count" or
        "min-count"/"max-count" ('*' means unbounded).
        :return: True if count of packets corresponds, False otherwise.
        """
        count = spec.get('count')
        if count is not None:
            return len(packets) == count

        max_cnt = spec['max-count']
        if max_cnt == '*':
            max_cnt = float('inf')
        return spec['min-count'] <= len(packets) <= max_cnt

    def __packets_to_groups(self, packets, desc):
        """
        Method that divides packets to groups in "group" attack types.

        A packet is keyed by the value of the first "group-by" field it
        contains; packets with none of the fields are dropped.
        :param packets: List of input packets.
        :param desc: Parsed YAML description of the attack.
        :return: Dictionary with groups of packets.
        """
        groups = defaultdict(list)

        for packet in packets:
            for group_field in desc['group-by']:
                key = self.__get_field_value(packet,
                                             group_field['field-name'])
                if key is not None:
                    groups[key].append(packet)
                    break

        return groups

    def __remove_list_elems(self, rem_list, ref_list):
        """
        Method that returns the difference between rem_list and ref_list.
        :param rem_list: List to filter.
        :param ref_list: Elements to drop (compared with ==).
        :return: New list with elements of rem_list not present in ref_list,
        in their original order.
        """
        return [it for it in rem_list if it not in ref_list]

    def __tag_packets(self, packets):
        """
        Assigns sequence number (attribute "index") to each packet in the
        packet group (stream). The packets are modified in place.

        NOTE(review): the value is an int; xml.etree accepts that, lxml would
        require a string -- confirm the element implementation.
        :param packets: List of packets in the same group (stream).
        """
        for i, packet in enumerate(packets):
            packet.set('index', i)

    def __flatten_list(self, lst):
        """
        Method that creates list from list of lists.
        :param lst: List that has list elements.
        :return: Flattened list.
        """
        return [i for sublist in lst for i in sublist]

    def __get_packet_number(self, packet):
        """
        Method that gets number of the packet passed as an argument.
        :param packet: Packet whose number we want to get.
        :return: "show" attribute of the packet's "num" field. Raises
        AttributeError if the packet has no "num" field.
        """
        elem = packet.find('.//*[@name="num"]')
        return elem.get('show')

    def __get_packets_numbers(self, packets):
        """
        Method that gets numbers of the packets passed as an argument.
        :param packets: List of packets whose numbers we want to get.
        :return: List of packet numbers, in input order.
        """
        return [self.__get_packet_number(packet) for packet in packets]

    def __union(self, a, b):
        """
        Returns union of two lists (order not preserved; currently unused).
        :param a: First list.
        :param b: Second list.
        :return: Union of input lists.
        """
        return list(set(a) | set(b))

