#!/usr/bin/env python3

""" Data.py: Main data-based classifier for domain classification.
    v1, based on: https://www.fit.vut.cz/study/thesis/25126/.cs?year=2021
    Many improvements are still needed; this is only a proof of concept.
"""
__author__      = "Jan Polisensky"

from ast import If, Try
from classifiers.Classifier import Classifier
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np

import json
import re
import datetime
import tensorflow as tf
import torch
from datetime import timedelta
import dateutil.parser




class Data(Classifier):
    """
    ! Data-based domain classifier

    Translates the DNS/WHOIS, SSL and GEO resolver outputs for a domain into a
    fixed-length feature vector (see __translate_input) and derives a badness
    score from the sum of its features. The pickled PyTorch model is evaluated
    as well: its output is reported in the explanation and becomes the final
    badness only for domains that look DGA-generated.
    """

    # Number of features the model was trained on; see __translate_input.
    _FEATURE_COUNT = 13

    # Feature score for the issuing organization of the root certificate;
    # any other (or missing) issuer scores 0.0.
    _ISSUER_RATINGS = {
        'Google Trust Services LLC': 5.0,
        'Amazon': 5.0,
        "Let's Encrypt": 2.0,
        "Cloudflare, Inc.": 2.0,
        'DigiCert Inc': 1.0,
    }

    def __init__(self):
        """
        ! Constructor of the Data-based classifier

        Loads the TLD rating table and the data model eagerly. Each is stored
        as False when loading fails; classify() then reports the error.
        """
        super().__init__()
        self.name = "data"
        self.file_name = "v1.3_0.12err.pt"
        self.json_vocab = "tlds.json"
        self.ratings = self.__load_rating()
        self.external_requires = ['dns_whois']
        self.external_wants = ['ssl', 'geo']
        self.classifier_wants = []
        self.classifier_requires = []
        self.final = False
        self.data_model = self.__load_model()

    def __load_model(self):
        """
        ! Load the data model from its binary

        @return The unpickled model, or False when it cannot be loaded
        """
        data_model_path = os.path.join(self.models_path, self.file_name)

        try:
            # classify() feeds CPU tensors, so map the weights onto the CPU even
            # if the model was saved from a GPU.
            # NOTE: torch.load unpickles arbitrary objects; only ever point it
            # at trusted model files.
            data_model = torch.load(data_model_path, map_location="cpu")
        except Exception as e:
            # torch.load raises on a missing/corrupt file instead of returning a
            # falsy value; report it so classify() can return a proper error.
            print("[Error]: Cant load data model, using path: " + data_model_path, e)
            return False

        if not data_model:
            print("[Error]: Cant load data model, using path: " + data_model_path)
            return False
        return data_model

    def __load_rating(self):
        """
        ! Load the TLD badness ratings (tlds.json)

        @return The parsed JSON (iterated by __tld_feature; every entry needs
                'name' and 'badnes' keys), or False when the file is missing,
                unreadable or not valid JSON
        """
        dict_path = os.path.join(self.external_model_data, self.json_vocab)

        try:
            with open(dict_path, encoding="utf-8") as json_file:
                return json.load(json_file)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return False

    def __translate_input(self, domain_name, data_in):
        """
        ! Translate resolver data into the model's feature vector

        v1 is very simple: a lot of the resolver information is just thrown
        away. TODO: in later models, use all information passed by resolvers.
        The translation is intentionally rough - it mirrors the one from the
        original bachelor's thesis and is meant to be redone properly later.

        ALERT: the model was trained on exactly this encoding. Changing the
        order, count or scale of any feature breaks the model.

        Feature layout (always _FEATURE_COUNT == 13 values):
          0      TLD rating: mean of badnes * -5 over matching ratings, -3 if none match
          1      domain level: 1 dot -> 1.0, 2 dots -> 0.75, otherwise -0.25
          2      any digit in the name -> -10, otherwise 0.0
          3-4    SSL: root cert validity score, issuer score (_ISSUER_RATINGS);
                 [-5.0, 0.0] without a root cert, [-3.0, 0.0] if the SSL data is unusable
          5-6    GEO: first entry's 'loc' split on ',': first / 90 * 5,
                 second / 180 * 5 (presumably lat/lon); [0, 0] if unavailable
          7-10   DNS: TXT, MX, SOA, NS presence; [-5.0] * 4 if the DNS data is unusable
          11-12  WHOIS: registration length score, DNSSEC (5.0 / 0);
                 [-5.0, 0.0] if the WHOIS data is unusable

        Each data-dependent section is computed all-or-nothing (see __section),
        so a failure part-way through a section cannot shift the features
        that follow it.

        @param domain_name Domain name being classified
        @param data_in Dict with 'dns_data', 'ssl_data', 'whois_data', 'geo_data'
        @return List of _FEATURE_COUNT numbers
        @raise RuntimeError if the vector length is wrong (internal bug)
        """
        f_vector = [
            self.__tld_feature(domain_name),
            self.__level_feature(domain_name),
            # Digits in a domain name are likely really bad.
            -10 if any(c.isdigit() for c in domain_name) else 0.0,
        ]

        f_vector += self.__section(
            self.__ssl_features, data_in, [-3.0, 0.0],
            "[Warning]: skipping ssl translation, not clearly error, but check it: ")
        # Missing geo data is common, so its fallback is deliberately silent.
        f_vector += self.__section(self.__geo_features, data_in, [0, 0])
        f_vector += self.__section(
            self.__dns_features, data_in, [-5.0, -5.0, -5.0, -5.0],
            "[Warning]: skipping dns translation, not clearly error, but check it: ")
        f_vector += self.__section(
            self.__whois_features, data_in, [-5.0, 0.0],
            "[Warning]: skipping part whois, not clearly error, but check it: ")

        if len(f_vector) != self._FEATURE_COUNT:
            raise RuntimeError("Feature vector has " + str(len(f_vector))
                               + " features, expected " + str(self._FEATURE_COUNT))
        return f_vector

    def __section(self, extract, data_in, fallback, warning=None):
        """
        ! Compute one feature section all-or-nothing

        @param extract Callable taking data_in and returning the section's features
        @param data_in Merged resolver data
        @param fallback Features used when extract raises
        @param warning Message printed together with the exception, if given
        @return extract(data_in), or a copy of fallback when it raised
        """
        try:
            return extract(data_in)
        except Exception as e:
            # Best effort by design: any malformed/missing resolver data falls
            # back to the section's default features.
            if warning is not None:
                print(warning, e)
            return list(fallback)

    def __tld_feature(self, domain_name):
        """
        ! Mean TLD badness (scaled by -5) of all ratings matching domain_name,
          or -3 when none matches
        """
        rating_sum = 0
        matched = 0
        for rating in self.ratings:
            # NOTE(review): the pattern is neither anchored to the end nor
            # escaped, so e.g. '.com' also matches 'a.community.org'. Kept as-is
            # because the model was trained on this behavior.
            if self.regex_cnt(domain_name, r"\." + rating['name']):
                rating_sum += float(rating['badnes']) * (-5)
                matched += 1

        if matched == 0:
            return -3
        return float(rating_sum / matched)

    @staticmethod
    def __level_feature(domain_name):
        """! 1.0 for one dot, 0.75 for two dots, -0.25 for any other count"""
        return {1: 1.0, 2: 0.75}.get(domain_name.count('.'), -0.25)

    @staticmethod
    def __to_datetime(value):
        """! Return value unchanged if it is a datetime, otherwise parse it with dateutil"""
        if isinstance(value, datetime.datetime):
            return value
        return dateutil.parser.parse(value)

    @staticmethod
    def __duration_score(duration, long_days, mid_days, short_days):
        """
        ! Bucket a timedelta: 5.0 above long_days, 3.0 above mid_days,
          1.5 above short_days, 0.5 otherwise
        """
        if duration > timedelta(days=long_days):
            return 5.0
        if duration > timedelta(days=mid_days):
            return 0.6 * 5
        if duration > timedelta(days=short_days):
            return 0.3 * 5
        return 0.1 * 5

    def __ssl_features(self, data_in):
        """! SSL section: [root cert validity score, root cert issuer score]"""
        root_cert = None
        for cert in data_in['ssl_data']['certs_data']:
            # Identity check: only the bool True marks a root cert; if there are
            # several, the last one wins.
            if cert['is_root'] is True:
                root_cert = cert

        if root_cert is None:
            # No root certificate available, so the issuer is unknown too.
            return [-5.0, 0.0]

        cert_duration = (self.__to_datetime(root_cert['validity_end'])
                         - self.__to_datetime(root_cert['validity_start']))

        issuer = root_cert.get('organization')
        issuer_score = self._ISSUER_RATINGS.get(issuer, 0.0) if isinstance(issuer, str) else 0.0

        return [self.__duration_score(cert_duration, 300, 80, 30), issuer_score]

    @staticmethod
    def __geo_features(data_in):
        """! GEO section: scaled coordinates of the first geo entry, [0, 0] if it is empty"""
        first_entry = data_in['geo_data'][0]
        if not first_entry:
            return [0, 0]

        coordinates = first_entry['loc'].split(",")
        return [(float(coordinates[0]) / 90) * 5, (float(coordinates[1]) / 180) * 5]

    @staticmethod
    def __dns_features(data_in):
        """! DNS section: presence scores of the TXT, MX, SOA and NS records"""
        dns_data = data_in['dns_data']
        return [
            1.0 if dns_data['TXT'] is not None else 0.0,
            2.0 if dns_data['MX'] is not None else 0.0,
            1.0 if dns_data['SOA'] is not None else -1.0,
            0.5 if dns_data['NS'] is not None else -0.5,
        ]

    def __whois_features(self, data_in):
        """! WHOIS section: [registration length score, DNSSEC score]"""
        whois_data = data_in['whois_data']

        registration_duration = (self.__to_datetime(whois_data['expiration_date'])
                                 - self.__to_datetime(whois_data['creation_date']))
        dnssec_score = 5.0 if whois_data['dnssec'] is not None else 0

        return [self.__duration_score(registration_duration, 5000, 3000, 1000), dnssec_score]

    def __calc_accuracy_and_explanation(self, meta_calc, result, explanation):
        """
        ! Estimate the reliability of result and extend the explanation

        @param meta_calc Dict with 'no_ssl', 'no_geo' (bool) and 'raw_score' (number)
        @param result Final badness score that will be reported
        @param explanation List of strings, extended in place
        @return (accuracy clamped to [0, 1], explanation)
        """
        accuracy = 0.8

        if meta_calc['no_ssl']:
            explanation.append('No SSL/TLS data')
            accuracy -= 0.1

        if meta_calc['no_geo']:
            explanation.append('No GEO data')
            accuracy -= 0.1

        if meta_calc['raw_score'] > 25 and result < 0.3:
            accuracy += 0.2

        if meta_calc['raw_score'] < 10 and result > 0.7:
            accuracy += 0.2

        if 0.3 < result < 0.7:
            accuracy -= 0.15
            explanation.append('Model in grey-zone, uncertain result')

        explanation.append('Raw domain score: ' + str(round(meta_calc['raw_score'], 2)))

        return min(max(accuracy, 0), 1), explanation

    def classify(self, domain_name, internal_data, external_data):
        """
        ! Perform the classification of the given domain_name
        @param domain_name Domain name to classify, e.g. 'fit.vut.cz'
        @param internal_data Dictionary of dependency classifiers' outputs (unused here)
        @param external_data Dictionary of external inputs, handed to parse_input_data()
        @return Result dict (classifier_name, success, error_description, badness,
                accuracy, explanation, final, created), or the err_handler()
                result when the model or the TLD ratings failed to load
        """
        ### Check model loading and model data ###
        if not self.data_model:
            return self.err_handler("Data model not loaded, check model path, using: " + self.models_path, 2)

        if not self.ratings:
            return self.err_handler("Cant load tlds rating, missing tlds.json in model_data file", 3)

        ### Parse and check input data ###
        self.parse_input_data(external_data)

        ### Load data into structure for translation ###
        merged_data = {
            'dns_data': self.get_input_data('dns'),
            'ssl_data': self.get_input_data('ssl'),
            'whois_data': self.get_input_data('whois'),
            'geo_data': self.get_input_data('geo'),
        }

        torch_input = torch.tensor(self.__translate_input(domain_name, merged_data))

        # Invert the model output into a badness value. Inference only, so no
        # autograd graph is needed; the model is evaluated once and reused.
        # NOTE(review): if the model contains dropout/batch-norm, it may need
        # .eval() after loading - confirm against the training code.
        with torch.no_grad():
            model_prediction = float(1) - float(self.data_model(torch_input))

        # Heuristic badness from the feature sum (as stored in the tensor):
        # raw_score 25 -> 0, -25 -> 1, linear in between, clamped to [0, 1].
        raw_score = sum(float(score) for score in torch_input)
        raw_index = 1 - ((raw_score / float(25)) + 1) / float(2)
        prediction = min(max(raw_index, 0), 1)

        explanation = []

        ### For possible DGA use the model prediction instead ###
        digit_count = sum(c.isdigit() for c in domain_name)
        level_count = domain_name.count('.')
        if digit_count > 3 or level_count > 4:
            prediction = model_prediction
            explanation.append("Data analysis not bad, but could be DGA, check similar domains manually")

        # SSL data are stored in a dict; geo data in a list (multiple domains).
        meta_calc = {
            'no_ssl': not isinstance(self.get_input_data('ssl'), dict),
            'no_geo': not isinstance(self.get_input_data('geo'), list),
            'raw_score': raw_score,
        }

        accuracy, explanation = self.__calc_accuracy_and_explanation(meta_calc, prediction, explanation)

        explanation.append("Model prediction: " + str(model_prediction))

        print("data Prediction ok: ", prediction)

        return {
            "classifier_name": self.getName(),
            "success": True,
            "error_description": '',
            "badness": float(prediction),
            "accuracy": accuracy,
            "explanation": explanation,
            "final": self.isFinal(),
            "created": datetime.datetime.now(),
        }