#!/usr/bin/env python3

""" Data.py: Main data-based classifier for domain classification
    v1, based on: https://www.fit.vut.cz/study/thesis/25126/.cs?year=2021
    Many improvements are still needed; this is just a proof of concept.
"""
__author__      = "Jan Polisensky"



from classifiers.Classifier import Classifier
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np

import json
import re
import datetime
import tensorflow as tf
import torch
from datetime import timedelta
import pickle
import dateutil.parser


class SVM(Classifier):
    """
    ! Data-based domain classifier backed by a pickled SVM model

    Translates DNS/WHOIS, SSL and geo data about a domain into a fixed-length
    numeric feature vector, scores it with the loaded model and blends that
    prediction with a hand-crafted heuristic score computed from the same vector.
    """

    # Number of features in the vector fed to the model. Every section helper
    # below returns a fixed number of values, so the layout never shifts.
    FEATURE_COUNT = 13

    # Score for the organization of the root SSL certificate; organizations
    # not listed here score 0.0.
    _ISSUER_SCORES = {
        'Google Trust Services LLC': 5.0,
        'Amazon': 5.0,
        "Let's Encrypt": 2.0,
        "Cloudflare, Inc.": 2.0,
        'DigiCert Inc': 1.0,
    }

    def __init__(self):
        """
        ! Constructor of the Data-based classifier
        """
        super().__init__()
        self.name = "svm"
        self.file_name = "svm_final.svm"
        self.external_requires = ['dns_whois']
        self.json_vocab = "tlds.json"
        # json_vocab must be set before the ratings are loaded.
        self.ratings = self.__load_rating()
        self.external_wants = ['ssl', 'geo']
        self.classifier_deps = []
        self.final = False
        # file_name must be set before the model is loaded.
        self.data_model = self.__load_model()

    def __load_model(self):
        """
        ! Load the pickled data model from models_path/file_name
        @return The unpickled model, or False when it cannot be loaded
        """
        data_model_path = os.path.join(self.models_path, self.file_name)

        # NOTE: unpickling can execute arbitrary code; only load model files
        # that come from a trusted source.
        try:
            with open(data_model_path, 'rb') as model_file:
                data_model = pickle.load(model_file)
        except (OSError, EOFError, pickle.UnpicklingError, ImportError, AttributeError) as e:
            # Report and return False so classify() can answer via err_handler
            # instead of the constructor crashing.
            print("[Error]: Cant load svm model, using path: " + data_model_path + " (" + str(e) + ")")
            return False

        if not data_model:
            print("[Error]: Cant load svm model, using path: " + data_model_path)
            return False
        return data_model

    def __load_rating(self):
        """
        ! Load the TLD rating vocabulary (JSON) from external_model_data/json_vocab
        @return Parsed JSON content, or False when the file is missing or invalid
        """
        dict_path = os.path.join(self.external_model_data, self.json_vocab)

        try:
            with open(dict_path, encoding='utf-8') as json_file:
                return json.load(json_file)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and decoding errors.
            # classify() reports falsy ratings through err_handler.
            return False

    def __translate_input(self, domain_name, data_in):
        """
        ! Translate the domain name and collected data into the model's feature vector
        @param domain_name Domain name being classified
        @param data_in Dict with 'dns_data', 'ssl_data', 'whois_data' and 'geo_data' entries
        @return List of FEATURE_COUNT numbers, or False if the layout is inconsistent

        Layout: [tld rating, domain level, digits, ssl duration, ssl issuer,
                 geo[0], geo[1], TXT, MX, SOA, NS, whois duration, dnssec]
        """
        # Sections are evaluated in this order, so any warnings print in it too.
        f_vector = (
            self.__tld_features(domain_name)
            + self.__name_features(domain_name)
            + self.__ssl_features(data_in.get('ssl_data'))
            + self.__geo_features(data_in.get('geo_data'))
            + self.__dns_features(data_in.get('dns_data'))
            + self.__whois_features(data_in.get('whois_data'))
        )

        # Each helper returns a fixed-size list, so this only trips if a helper
        # is edited inconsistently.
        if len(f_vector) != self.FEATURE_COUNT:
            print("[Fatal Error]: Feature vector has unexpected length " + str(len(f_vector))
                  + ", expected " + str(self.FEATURE_COUNT))
            return False

        return f_vector

    def __tld_features(self, domain_name):
        """
        ! Feature 1: mean TLD badness (scaled by -5) over all rating entries matching the domain
        @return [score]; [-3] when no rating entry matches
        """
        total = 0
        matches = 0
        for rating in self.ratings:
            # regex_cnt is not defined in this file (presumably on the Classifier base).
            # NOTE(review): if it performs an unanchored search, '\.com' also
            # matches e.g. 'x.company.cz'. Kept as-is because the model was trained on it.
            if self.regex_cnt(domain_name, r"\." + rating['name']):
                total += float(rating['badnes']) * (-5)
                matches += 1

        if matches == 0:
            return [-3]
        return [float(total / matches)]

    @staticmethod
    def __name_features(domain_name):
        """
        ! Features 2-3: domain level score and digit penalty
        @return [level score, digit score]
        """
        # One dot (e.g. 'vut.cz') scores best, two dots next, anything else is penalized.
        level_score = {1: 1.0, 2: 0.75}.get(domain_name.count('.'), -0.25)
        # Digits in the domain name are treated as a strong badness signal.
        digit_score = -10 if any(c.isdigit() for c in domain_name) else 0.0
        return [level_score, digit_score]

    @staticmethod
    def __duration_score(duration, long_days, mid_days, short_days):
        """
        ! Bucket a timedelta into 5.0 / 3.0 / 1.5 / 0.5 using three day thresholds
        """
        if duration > timedelta(days=long_days):
            return 5.0
        if duration > timedelta(days=mid_days):
            return 0.6 * 5
        if duration > timedelta(days=short_days):
            return 0.3 * 5
        return 0.1 * 5

    @staticmethod
    def __as_datetime(value):
        """
        ! Return value unchanged if it is already a datetime, otherwise parse it with dateutil
        """
        if isinstance(value, datetime.datetime):
            return value
        return dateutil.parser.parse(value)

    def __ssl_features(self, ssl_data):
        """
        ! Features 4-5: root certificate validity length and issuer score
        @param ssl_data SSL input; expected to contain 'certs_data'
        @return [duration score, issuer score]; [-5.0, 0.0] when no usable root cert
        """
        try:
            root_cert = None
            for cert in ssl_data['certs_data']:
                if cert['is_root'] is True:
                    root_cert = cert  # the last root certificate wins

            if root_cert is None:
                return [-5.0, 0.0]

            validity = (dateutil.parser.parse(root_cert['validity_end'])
                        - dateutil.parser.parse(root_cert['validity_start']))
            duration_score = self.__duration_score(validity, 300, 80, 30)

            issuer = root_cert['organization']
            # An empty organization still contributes a (neutral) feature.
            issuer_score = self._ISSUER_SCORES.get(issuer, 0.0) if issuer else 0.0

            return [duration_score, issuer_score]
        except Exception as e:
            print("[Warning]: skipping ssl translation, not clearly error, but check it: ", e)
            return [-5.0, 0.0]

    @staticmethod
    def __geo_features(geo_data):
        """
        ! Features 6-7: scaled coordinates of the first geo record
        @return [coord0 / 90 * 5, coord1 / 180 * 5]; [0, 0] when unavailable
        """
        try:
            location = geo_data[0]
            if not location:
                return [0, 0]
            # 'loc' is presumably "latitude,longitude" (scaled by 90 and 180).
            coordinates = location['loc'].split(",")
            return [(float(coordinates[0]) / 90) * 5, (float(coordinates[1]) / 180) * 5]
        except Exception:
            # Geo data is optional (listed in external_wants), so any problem
            # deliberately falls back to neutral zeros without a warning.
            return [0, 0]

    @staticmethod
    def __dns_features(dns_data):
        """
        ! Features 8-11: presence of TXT, MX, SOA and NS records
        @return Four scores; [-5.0] * 4 when the DNS data cannot be read
        """
        try:
            return [
                1.0 if dns_data['TXT'] is not None else 0.0,
                2.0 if dns_data['MX'] is not None else 0.0,
                1.0 if dns_data['SOA'] is not None else -1.0,
                0.5 if dns_data['NS'] is not None else -0.5,
            ]
        except Exception as e:
            print("[Warning]: skipping dns translation, not clearly error, but check it: ", e)
            return [-5.0] * 4

    def __whois_features(self, whois_data):
        """
        ! Features 12-13: registration length and DNSSEC presence
        @return [duration score, dnssec score]; [-5.0, 0.0] when the WHOIS data cannot be read
        """
        try:
            end_date = self.__as_datetime(whois_data['expiration_date'])
            start_date = self.__as_datetime(whois_data['creation_date'])
            duration_score = self.__duration_score(end_date - start_date, 5000, 3000, 1000)
            dnssec_score = 5.0 if whois_data['dnssec'] is not None else 0
            return [duration_score, dnssec_score]
        except Exception as e:
            print("[Warning]: skipping part whois, not clearly error, but check it: ", e)
            return [-5.0, 0.0]

    def classify(self, domain_name, internal_data, external_data):
        """
        ! Perform the classification of the given domain_name
        @param domain_name Domain name to classify, e.g. 'fit.vut.cz'
        @param internal_data Dictionary of dependency classifiers' outputs (unused here)
        @param external_data Dictionary of external inputs (DNS/WHOIS, SSL and geo data)
        @return Returns the classification output
        """
        ### Check model loading and model data ###
        if not self.data_model:
            return self.err_handler("Data model not loaded, check model path, using: " + self.models_path, 2)

        if not self.ratings:
            return self.err_handler("Cant load tlds rating, missing tlds.json in model_data file", 3)

        ### Parse and check input data ###
        self.parse_input_data(external_data)

        ### Load data into structure for translation ###
        merged_data = {
            'dns_data': self.get_input_data('dns'),
            'ssl_data': self.get_input_data('ssl'),
            'whois_data': self.get_input_data('whois'),
            'geo_data': self.get_input_data('geo'),
        }

        translated_vector = self.__translate_input(domain_name, merged_data)
        if translated_vector is False:
            # NOTE(review): error code reused from the model check; confirm the intended code.
            return self.err_handler("Cant translate input data into feature vector", 2)

        np_input = np.array([translated_vector], dtype=np.float32)

        # Inverted model output. NOTE(review): assumes predict() returns 1 for
        # benign domains, so this yields 0 for benign / 1 for bad; confirm against training.
        prediction = float(1) - self.data_model.predict(np_input)[0]

        # Heuristic index: plain sum of all features, mapped linearly so that
        # +25 -> 0.0 and -25 -> 1.0 (not clamped).
        raw_score = sum(float(score) for score in translated_vector)
        raw_index = 1 - ((raw_score / float(25)) + 1) / float(2)

        # Below 0.4 the heuristic gets weight 0.8 and the model 0.2; otherwise
        # the weights are swapped.
        if raw_index < 0.4:
            raw_index = prediction * 0.2 + raw_index * 0.8
        else:
            raw_index = prediction * 0.8 + raw_index * 0.2

        return {
            "classifier_name": self.getName(),
            "success": True,
            "error_description": '',
            "badness": float(raw_index),
            "accuracy": 1,
            "explanation": [],
            "final": self.isFinal(),
            "created": datetime.datetime.now()
        }