import os
import numpy as np
from glob import glob
from python_speech_features import mfcc
from scipy.io import wavfile
import scipy.stats as stats
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import MinMaxScaler
from math import log
import argparse

def wav16khz2mfcc(dir_name, skip_sec=0):
    """
    Extracts MFCC features from WAV files in a directory.

    Only files matching '*.wav' directly inside dir_name are read (no recursion).
    Each file is processed at the sample rate stored in the file itself.
    Args:
        dir_name: Directory containing WAV files.
        skip_sec: Number of seconds to skip at the beginning of each wav.
            May be fractional; the offset is truncated to a whole sample.
    Returns:
        dict: Dictionary with file paths (as returned by glob) as keys and MFCC features as values.
    """
    features = {}
    # glob yields files in arbitrary order; sorting makes the dictionary order
    # (and therefore everything derived from it) reproducible between runs.
    for f in sorted(glob(dir_name + '/*.wav')):
        rate, s = wavfile.read(f)
        # Slice bounds must be integers, so convert seconds to a whole number
        # of samples; a float skip_sec would otherwise raise TypeError.
        s = s[int(skip_sec * rate):]
        # 25 ms windows with a 10 ms step, 36 filters and 36 cepstral coefficients.
        mfcc_features = mfcc(s, samplerate=rate, winlen=0.025, winstep=0.01, numcep=36, nfilt=36, nfft=1024)
        features[f] = mfcc_features
    return features

def prepare_wav(data):
    """
    Builds one fixed-length feature vector and one label per recording for the SVM.

    For every recording the per-coefficient mean, standard deviation, skewness and
    kurtosis over all frames are concatenated (in that order) into a single vector.
    The label is parsed from characters 2-3 of the part of the file name before the
    first underscore; names that do not contain a number there get label 0.
    Labels are finally re-indexed to 0..K-1 following the sorted distinct labels
    present in this call, so the numbering depends on which labels occur in `data`.
    Args:
        data: Dictionary with file names as keys and MFCC features as values.
    Returns:
        tuple: (features, labels) as numpy arrays, in the iteration order of `data`.
    """
    def _label_from_path(path):
        # e.g. ".../f405_01_x.wav" -> "f405" -> "05" -> 5
        speaker_token = path.split('/')[-1].split('_')[0]
        try:
            return int(speaker_token[2:4])
        except (ValueError, IndexError):
            return 0

    def _summary_vector(frames):
        # All statistics are taken along the frame axis (axis 0).
        return np.hstack((
            np.mean(frames, axis=0),
            np.std(frames, axis=0),
            stats.skew(frames, axis=0),
            stats.kurtosis(frames, axis=0),
        ))

    raw_labels = []
    vectors = []
    for path, frames in data.items():
        raw_labels.append(_label_from_path(path))
        vectors.append(_summary_vector(frames))

    # Compress the original label values into consecutive class indices.
    index_of = {}
    for position, value in enumerate(sorted(set(raw_labels))):
        index_of[value] = position

    features = np.array(vectors)
    labels = np.array([index_of[value] for value in raw_labels])
    return features, labels

def store_results(output_file, vdata, val_preds, val_probs):
    """
    Writes one space-separated result line per recording to a text file.

    Each line holds the file name (path and everything from the first dot removed),
    the predicted class shifted by +1, and the natural log of every class
    probability formatted with two decimals.
    Args:
        output_file: File to save the predictions (overwritten if it exists).
        vdata: Dictionary with file names as keys; only the keys and their order are used.
        val_preds: Predicted class indices, aligned with the key order of vdata.
        val_probs: Per-class probabilities, aligned with the key order of vdata.
    """
    def _format_row(position, path):
        stem = path.split('/')[-1].split('.')[0]
        predicted = val_preds[position]
        # The small offset keeps log() finite when a probability is exactly zero.
        log_probs = np.log(val_probs[position] + 1e-10)
        scores = " ".join([f"{value:.2f}" for value in log_probs])
        return f"{stem} {predicted + 1} " + scores + "\n"

    with open(output_file, "w") as out:
        out.writelines(
            _format_row(position, path)
            for position, path in enumerate(vdata.keys())
        )

def collect_data(path):
    """
    Walks a directory tree and gathers MFCC features from every folder that holds WAV files.
    Args:
        path: Root directory to search; all subdirectories are visited as well.
    Returns:
        dict: Merged dictionary of the per-folder results of wav16khz2mfcc
            (file paths as keys, MFCC features as values).
    """
    collected = {}
    for folder, _subdirs, filenames in os.walk(path):
        has_wav = False
        for name in filenames:
            # The extension check is case-insensitive.
            if name.lower().endswith('.wav'):
                has_wav = True
                break
        # Folders without any WAV file are skipped entirely.
        if not has_wav:
            continue
        collected.update(wav16khz2mfcc(folder))
    return collected

def svm_predict(vdata_path, tdata_path, output_file="audio_SVM_results.csv"):
    """
    Trains an RBF-kernel SVM on the training data and predicts the validation data.
    Results are saved to a file.
    Args:
        vdata_path: Path to validation data directory.
        tdata_path: Path to training data directory.
        output_file: File to save the predictions.
    Raises:
        ValueError: If no WAV files are found under tdata_path or vdata_path.
    """
    vdata = collect_data(vdata_path)
    tdata = collect_data(tdata_path)

    # Fail early with a clear message; otherwise the scaler rejects the empty
    # feature array with a far less helpful ValueError.
    if not tdata:
        raise ValueError(f"No WAV files found in training directory: {tdata_path}")
    if not vdata:
        raise ValueError(f"No WAV files found in validation directory: {vdata_path}")

    train_features, train_labels = prepare_wav(tdata)
    # Only the validation features are needed for prediction. prepare_wav numbers
    # labels independently per call, so the validation labels are not guaranteed
    # to be comparable with the training labels and are ignored here.
    val_features, _ = prepare_wav(vdata)

    # Fit the scaler on the training set only and reuse it for validation.
    scaler = MinMaxScaler()
    X_train = scaler.fit_transform(train_features)
    X_val = scaler.transform(val_features)

    # probability=True is required for predict_proba below.
    svm_model = SVC(kernel='rbf', C=10.0, gamma='scale', probability=True)
    svm_model.fit(X_train, train_labels)

    val_preds = svm_model.predict(X_val)
    val_probs = svm_model.predict_proba(X_val)

    store_results(output_file, vdata, val_preds, val_probs)

def main():
    """
    Parses the command-line options and runs the SVM prediction.

    Both data directories are mandatory; the output file is optional.
    """
    cli = argparse.ArgumentParser(description="Run SVM prediction on audio data.")
    cli.add_argument(
        "-v", "--vdata_path",
        type=str,
        required=True,
        help="Path to validation data directory.",
    )
    cli.add_argument(
        "-t", "--tdata_path",
        type=str,
        required=True,
        help="Path to training data directory.",
    )
    cli.add_argument(
        "-o", "--output_file",
        type=str,
        default="audio_SVM_results.csv",
        help="File to save the predictions.",
    )
    options = cli.parse_args()

    svm_predict(options.vdata_path, options.tdata_path, options.output_file)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()