import os

import torch
from torch.utils.data import DataLoader

from audio_model import AudioModel
from image_model import MyCNN
from dataParser import AudioDataset, ImageDataset, audio_cutter, get_normalized
from project import create_my_best_image_model


def _format_line(label, predicted_class, scores):
    """Return one result line: ``<label> <predicted_class> <score_1> ... <score_N>``.

    :param label: Sample identifier as yielded by the dataset.
    :param predicted_class: 1-based class index of the hard decision.
    :param scores: Iterable of per-class scores, each printed with 6 decimals.
    :return: The formatted line, terminated by a newline.
    """
    formatted_scores = " ".join(f"{score:.6f}" for score in scores)
    return f"{label} {predicted_class} {formatted_scores}\n"


def _write_sorted(path, lines_by_label):
    """Write the result lines to *path*, ordered by their label key.

    The file is opened only after all results have been computed, so a failure
    during evaluation never truncates a previously written results file.
    """
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(lines_by_label[key] for key in sorted(lines_by_label))


def _evaluate_audio(audio_model, audio_data):
    """Run the audio model on every sample and return ``{label: result_line}``.

    NOTE(review): results are keyed by label, so if two samples share a label
    only the last one is kept -- confirm labels are unique per recording.
    """
    results = {}
    for audio_features, label in audio_data:
        # predict() takes a batch (list) of feature sets; here a batch of one.
        predicted_label, log_probs = audio_model.predict([audio_features])
        # The model's class index is 0-based; the output format is 1-based.
        results[label] = _format_line(label, predicted_label.item() + 1, log_probs)
    return results


def _evaluate_images(image_model, image_data):
    """Run the image model on every sample and return ``{label: result_line}``.

    Scores are log-posteriors obtained with ``log_softmax`` over the network
    outputs. If the network already ends in log-softmax this is a no-op, since
    log_softmax is idempotent. The argmax decision is unaffected either way.
    """
    results = {}
    with torch.no_grad():  # inference only; no autograd graph needed
        for i in range(len(image_data.image_data)):
            image_features, label = image_data[i]
            logits = image_model(image_features.unsqueeze(0))  # add batch dimension
            log_probs = torch.log_softmax(logits, dim=1).squeeze(0).tolist()
            # 0-based class index -> 1-based class number in the output file.
            predicted_class = int(torch.argmax(logits, dim=1).item()) + 1
            results[label] = _format_line(label, predicted_class, log_probs)
    return results


def evaluate_models(eval_dir, audio_model_path, image_model_path,
                    audio_output_file="audio_GMM", image_output_file="image_CNN"):
    """
    Evaluate the best audio and image models on the eval dataset and save results.

    Each output file gets one line per sample, sorted by label:
    ``<label> <predicted class (1-based)> <score per class ...>``.

    :param eval_dir: Directory containing evaluation data.
    :param audio_model_path: Path to the saved audio model.
    :param image_model_path: Path to the saved image model.
    :param audio_output_file: Path of the audio results file (default ``"audio_GMM"``).
    :param image_output_file: Path of the image results file (default ``"image_CNN"``).
    """
    # Load the best models
    print("Loading the best models...")
    audio_model = AudioModel(path=audio_model_path)
    image_model = MyCNN(path=image_model_path)
    image_model.eval()  # inference mode for the CNN (e.g. dropout/batch-norm layers)

    # Prepare evaluation datasets
    print("Loading evaluation datasets...")
    # NOTE(review): audio_cutter runs on eval_dir before the dataset is built;
    # the meaning of 1.5 and whether it rewrites files in place is defined in
    # dataParser -- confirm.
    audio_cutter(eval_dir, 1.5)
    audio_data = AudioDataset([eval_dir], one_hot=False, unlabeled=True)
    image_data = ImageDataset([eval_dir], transform=get_normalized,
                              one_hot=False, unlabeled=True)

    print("Evaluating audio and image models...")
    # Compute all results first, then write, so a crash cannot leave a
    # truncated results file behind.
    _write_sorted(audio_output_file, _evaluate_audio(audio_model, audio_data))
    _write_sorted(image_output_file, _evaluate_images(image_model, image_data))

    print("Datasets evaluated, results saved")


# Example usage
if __name__ == "__main__":
    eval_dir = "SUR_projekt2024-2025_eval"
    audio_model_path = "best_gmm_audio.pickle"
    image_model_path = "CNN_image.pickle"
    evaluate_models(eval_dir, audio_model_path, image_model_path)