"""
I will use jackknife with the same settings as in v9 which reached 90% accuracy.
report:

Jackknife Evaluation Report:
              precision    recall  f1-score   support

           1       0.80      0.67      0.73         6
          10       0.67      0.67      0.67         6
          11       0.60      0.50      0.55         6
          12       0.80      0.67      0.73         6
          13       1.00      1.00      1.00         6
          14       0.75      1.00      0.86         6
          15       0.80      0.67      0.73         6
          16       1.00      1.00      1.00         6
          17       0.83      0.83      0.83         6
          18       0.86      1.00      0.92         6
          19       1.00      0.67      0.80         6
           2       0.83      0.83      0.83         6
          20       1.00      0.33      0.50         6
          21       1.00      0.67      0.80         6
          22       0.40      0.67      0.50         6
          23       0.67      0.33      0.44         6
          24       1.00      0.83      0.91         6
          25       1.00      0.83      0.91         6
          26       0.27      0.67      0.38         6
          27       0.67      1.00      0.80         6
          28       0.67      0.33      0.44         6
          29       0.67      1.00      0.80         6
           3       0.50      0.83      0.62         6
          30       1.00      0.33      0.50         6
          31       0.80      0.67      0.73         6
           4       0.29      0.33      0.31         6
           5       0.62      0.83      0.71         6
           6       0.50      0.33      0.40         6
           7       1.00      0.67      0.80         6
           8       0.67      0.67      0.67         6
           9       0.83      0.83      0.83         6

    accuracy                           0.70       186
   macro avg       0.76      0.70      0.70       186
weighted avg       0.76      0.70      0.70       186
"""

import os
os.environ["OMP_NUM_THREADS"] = "1"  # set before the numeric libraries spawn threads

import shutil

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import librosa  # pip install librosa
from scipy.io import wavfile
from python_speech_features import mfcc, delta  # pip install python_speech_features
from sklearn.mixture import GaussianMixture
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import StandardScaler

TRAIN_DIR = "./Separate_data/train/sounds"
DEV_DIR = "./Separate_data/dev/sounds"
NUM_CEPS = 20
NFFT = 1024
SILENCE_TOP_DB = 25

def load_audio(src_folder):
    """Load per-class MFCC features. Kept for reference; unused below."""
    all_classes_mfcc_feats = []
    class_labels = []
    for person_class in sorted(os.listdir(src_folder)):
        class_dir = os.path.join(src_folder, person_class)
        
        class_mfcc_feats = []
        class_labels.append(person_class)

        for audio_record in sorted(os.listdir(class_dir)):
            audio_record_pth = os.path.join(class_dir, audio_record)
            freq_sampling, audio_sig = wavfile.read(audio_record_pth)

            audio_sig = audio_sig[20000:] if len(audio_sig) > 20000 else audio_sig # trim the first 20000 samples (~2 s at 10 kHz)

            mfcc_feats = mfcc(audio_sig, freq_sampling, numcep=NUM_CEPS, appendEnergy=False, nfft=NFFT) # extract MFCCs
            
            class_mfcc_feats.append(mfcc_feats) # add them to all mfcc for this class
        
        all_classes_mfcc_feats.append(class_mfcc_feats)
    return all_classes_mfcc_feats, class_labels

def remove_silence(audio_sig, top_db=SILENCE_TOP_DB):
    audio_float = audio_sig.astype(np.float32)  # librosa expects floating-point audio
    intervals = librosa.effects.split(audio_float, top_db=top_db)
    if len(intervals) == 0:
        return audio_sig  # whole file judged silent; keep it rather than return nothing
    non_silent = [audio_sig[start:end] for start, end in intervals]
    return np.concatenate(non_silent)
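
# Note: librosa's top_db threshold is relative to the clip's peak, so with
# SILENCE_TOP_DB = 25 any frame more than 25 dB below the peak is dropped:
#   sr, sig = wavfile.read("utt.wav")   # hypothetical file
#   voiced = remove_silence(sig)        # keeps only the loud intervals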

def get_feats(audio_sig, freq_sampling):
    audio_sig_no_silence = remove_silence(audio_sig)
    mfcc_feats = mfcc(audio_sig_no_silence, freq_sampling, numcep=NUM_CEPS, appendEnergy=False, nfft=NFFT) # extract mfcc
    mfcc_feats -= np.mean(mfcc_feats, axis=0, keepdims=True) # cepstral mean normalization: subtract each coefficient's column-wise mean to flatten channel/mic variation
    #delta_feats = delta(mfcc_feats, 2)
    #delta_delta_feats = delta(delta_feats, 2)
    #combined_feats = np.hstack((mfcc_feats, delta_feats, delta_delta_feats))
    return mfcc_feats
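
# A minimal sketch of the delta-stacking variant hinted at by the commented
# lines above (an assumption, not part of the evaluated v9 settings): each
# frame is extended with first- and second-order deltas.
def get_feats_with_deltas(audio_sig, freq_sampling):
    base = get_feats(audio_sig, freq_sampling)
    d1 = delta(base, 2)  # first-order deltas over a +/-2 frame window
    d2 = delta(d1, 2)    # second-order (acceleration) deltas
    return np.hstack((base, d1, d2))  # shape: (frames, 3 * NUM_CEPS)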

def compute_global_scaler(src_folder):
    all_feats = []
    for person_class in sorted(os.listdir(src_folder)):
        class_dir = os.path.join(src_folder, person_class)
        for audio_record in sorted(os.listdir(class_dir)):
            audio_record_pth = os.path.join(class_dir, audio_record)
            freq_sampling, audio_sig = wavfile.read(audio_record_pth)
            audio_sig = audio_sig[20000:] if len(audio_sig) > 20000 else audio_sig  # same trim as training/eval
            combined_feats = get_feats(audio_sig, freq_sampling)
            all_feats.append(combined_feats)
    all_feats = np.vstack(all_feats)
    scaler = StandardScaler()
    scaler.fit(all_feats)
    return scaler
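
# Usage sketch: one scaler fitted over all training frames keeps every class
# model in the same standardized feature space, e.g.
#   scaler = compute_global_scaler(TRAIN_DIR)
#   z = scaler.transform(feats)  # zero mean, unit variance per MFCC column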

def train_gmm(src_folder, scaler):

    class_features = {} # store features per class

    for person_class in sorted(os.listdir(src_folder)):
        class_dir = os.path.join(src_folder, person_class)
        
        class_mfcc_feats = []

        for audio_record in sorted(os.listdir(class_dir)):
            audio_record_pth = os.path.join(class_dir, audio_record)
            freq_sampling, audio_sig = wavfile.read(audio_record_pth)

            audio_sig = audio_sig[20000:] if len(audio_sig) > 20000 else audio_sig # trim the first 20000 samples (~2 s at 10 kHz)

            combined_feats = get_feats(audio_sig, freq_sampling)
            # add simple augmentation (disabled)
            #stretched = librosa.effects.time_stretch(audio_sig, rate=0.9)
            #mfcc_feats = np.vstack([mfcc_feats, mfcc(stretched[20000:], freq_sampling)]) #, nfft=2048)])

            class_mfcc_feats.append(combined_feats) # add them to all mfcc for this class

        if class_mfcc_feats:
            class_mfcc_feats = np.vstack(class_mfcc_feats)
            print(f"\nTotal frames for {person_class}: {class_mfcc_feats.shape[0]}")
            
            # Normalize features
            normalized_feats = scaler.transform(class_mfcc_feats)
            
            # train GMM for this class
            gmm = GaussianMixture(n_components=8, 
                                covariance_type='full',
                                max_iter=500,
                                n_init=3,
                                random_state=42)
            gmm.fit(normalized_feats)
            
            # store the GMM together with the shared global scaler
            class_features[person_class] = (gmm, scaler)

    return class_features

# Train the GMMs:
#scaler = compute_global_scaler(TRAIN_DIR)
#trained_models = train_gmm(TRAIN_DIR, scaler)

# To use for classification:
#test_features = get_feats(test_audio[20000:], freq_sampling)
#scores = {name: gmm.score(scaler.transform(test_features)) for name, (gmm, scaler) in trained_models.items()}
#predicted_class = max(scores.items(), key=lambda x: x[1])[0]
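
# Runnable version of the sketch above (a minimal helper, assuming the
# (gmm, scaler) tuples produced by train_gmm):
def classify_utterance(audio_sig, freq_sampling, models):
    feats = get_feats(audio_sig, freq_sampling)
    # GaussianMixture.score returns the mean per-frame log-likelihood
    scores = {name: g.score(s.transform(feats)) for name, (g, s) in models.items()}
    return max(scores.items(), key=lambda kv: kv[1])[0]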

def evaluate_models(models, test_dir, scaler):
    """Evaluate trained GMM models on test data"""
    true_labels = []
    pred_labels = []
    
    for person_class in sorted(os.listdir(test_dir)):
        class_dir = os.path.join(test_dir, person_class)
        
        for audio_record in sorted(os.listdir(class_dir)):
            try:
                audio_record_pth = os.path.join(class_dir, audio_record)
                freq_sampling, audio_sig = wavfile.read(audio_record_pth)
                
                audio_sig = audio_sig[20000:] if len(audio_sig) > 20000 else audio_sig
                test_features = get_feats(audio_sig, freq_sampling)
                
                # score against every class model; the unpacked scaler is named
                # cls_scaler to avoid shadowing the `scaler` argument
                scores = {
                    name: gmm.score(cls_scaler.transform(test_features))
                    for name, (gmm, cls_scaler) in models.items()
                }
                predicted_class = max(scores.items(), key=lambda x: x[1])[0]
                
                true_labels.append(person_class)
                pred_labels.append(predicted_class)
                
            except Exception as e:
                print(f"Error processing {audio_record_pth}: {str(e)}")
                continue
    
    return true_labels, pred_labels
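
# Usage sketch for the dev split (DEV_DIR is defined above but not exercised
# by the jackknife run below):
#   scaler = compute_global_scaler(TRAIN_DIR)
#   models = train_gmm(TRAIN_DIR, scaler)
#   y_true, y_pred = evaluate_models(models, DEV_DIR, scaler)
#   print(classification_report(y_true, y_pred, zero_division=0))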

def plot_confusion_matrix(true_labels, pred_labels, classes):
    """Plot confusion matrix"""
    cm = confusion_matrix(true_labels, pred_labels, labels=classes)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()
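
# Variant sketch: dividing each row by its sum turns counts into per-class
# recall, which reads more easily with 31 classes:
#   cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
#   sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues')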

# Main execution

if __name__ == "__main__":
    print("Running Jackknife Evaluation (Leave-One-Out)...")

    all_classes = sorted(os.listdir(TRAIN_DIR))
    all_results_true = []
    all_results_pred = []

    for person_class in all_classes:
        print(f"\nProcessing class {person_class}...")
        class_dir = os.path.join(TRAIN_DIR, person_class)
        audio_files = sorted(os.listdir(class_dir))

        for held_out_file in audio_files:
            # Create temporary training directory
            temp_train_dir = "./temp_train"
            if os.path.exists(temp_train_dir):
                shutil.rmtree(temp_train_dir)
            os.makedirs(temp_train_dir, exist_ok=True)

            # Copy all training files except held-out one
            for other_class in all_classes:
                src_class_dir = os.path.join(TRAIN_DIR, other_class)
                dst_class_dir = os.path.join(temp_train_dir, other_class)
                os.makedirs(dst_class_dir, exist_ok=True)

                for f in sorted(os.listdir(src_class_dir)):
                    if other_class == person_class and f == held_out_file:
                        continue
                    shutil.copy(os.path.join(src_class_dir, f), os.path.join(dst_class_dir, f))

            # Train on temp data
            scaler = compute_global_scaler(temp_train_dir)
            trained_models = train_gmm(temp_train_dir, scaler)

            # Test on held-out file
            held_out_path = os.path.join(class_dir, held_out_file)
            freq_sampling, audio_sig = wavfile.read(held_out_path)
            audio_sig = audio_sig[20000:] if len(audio_sig) > 20000 else audio_sig
            test_features = get_feats(audio_sig, freq_sampling)

            scores = {
                name: gmm.score(cls_scaler.transform(test_features))
                for name, (gmm, cls_scaler) in trained_models.items()
            }
            predicted_class = max(scores.items(), key=lambda x: x[1])[0]

            all_results_true.append(person_class)
            all_results_pred.append(predicted_class)

            # Clean up
            shutil.rmtree(temp_train_dir)

    # Report results
    print("\nJackknife Evaluation Report:")
    print(classification_report(all_results_true, all_results_pred, labels=all_classes, zero_division=0))
    plot_confusion_matrix(all_results_true, all_results_pred, all_classes)  # calls plt.show() itself