"""Hybrid classifier that fuses image-CNN and audio-SVM predictions.

Each model's results file is whitespace separated, one sample per line:
``<file name> <hard prediction> <score_1> ... <score_k>``.
The per-class scores of both models are z-score normalised, combined with a
weighted sum, and the resulting 1-based class labels are written to a new
results file.
"""
import argparse

import numpy as np


def load_predictions(file_path):
    """Load the outputs of a CNN or SVM model from a results file.

    Blank lines are skipped. Every other line must have the form
    ``<file name> <hard prediction> <score_1> ... <score_k>``.

    Args:
        file_path: Path to the file containing predictions.

    Returns:
        tuple: ``(file_names, hard_preds, log_probs)`` - a list of file
        names, a 1-D integer array of hard predictions, and an array of
        scores with one row per line (2-D when every line has the same
        number of scores).
    """
    file_names = []
    hard_preds = []
    log_probs = []
    # Default encoding kept on purpose: it must match whatever encoding the
    # model scripts used when writing these files.
    with open(file_path, "r") as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Tolerate blank lines (e.g. a trailing empty line) instead
                # of failing with an IndexError on parts[0].
                continue
            file_names.append(parts[0])
            hard_preds.append(int(parts[1]))
            log_probs.append([float(x) for x in parts[2:]])
    return file_names, np.array(hard_preds), np.array(log_probs)


def get_labels_from_filenames(files):
    """Derive contiguous 1-based class labels from file names.

    The raw label is the integer formed by characters 2-3 of the first
    ``_``-separated token of the part after the last ``/``; for example
    ``"dir/xx05_a.png"`` yields 5. Names that do not parse get the raw
    label 0. The distinct raw labels are sorted and remapped to ``1..K``.

    Args:
        files: Iterable of file names (paths split on ``/`` only).

    Returns:
        np.ndarray: One label in ``1..K`` per input file.
    """
    raw_labels = []
    for file in files:
        try:
            raw_labels.append(int(file.split("/")[-1].split("_")[0][2:4]))
        except (ValueError, IndexError):
            raw_labels.append(0)
    label_mapping = {
        orig: idx for idx, orig in enumerate(sorted(set(raw_labels)))
    }
    labels = np.array([label_mapping[label] for label in raw_labels])
    return labels + 1


def calculate_accuracy(predictions, true_labels):
    """Calculate the accuracy of the hybrid model as a percentage.

    Args:
        predictions: Predicted labels of the model.
        true_labels: True labels, aligned element-wise with ``predictions``.

    Returns:
        float: Percentage (0-100) of predictions equal to the true label.

    Raises:
        ValueError: If ``predictions`` is empty or the two sequences differ
            in length.
    """
    total = len(predictions)
    if total == 0:
        raise ValueError("cannot compute accuracy of an empty prediction set")
    if total != len(true_labels):
        raise ValueError(
            f"got {total} predictions but {len(true_labels)} true labels"
        )
    correct = sum(1 for pred, true_label in zip(predictions, true_labels)
                  if pred == true_label)
    return (correct / total) * 100


def normalize_log_probs(log_probs, eps=1e-10):
    """Z-score normalise scores column by column (per class over samples).

    Args:
        log_probs: Array of scores, one row per sample, one column per class.
        eps: Added to each column's standard deviation so constant columns
            map to 0 instead of dividing by zero.

    Returns:
        np.ndarray: Normalised scores with the same shape as ``log_probs``.
    """
    mean_vals = log_probs.mean(axis=0)
    std_vals = log_probs.std(axis=0)
    return (log_probs - mean_vals) / (std_vals + eps)


def _sort_by_file_name(file_names, log_probs):
    """Return file names and score rows stably reordered by file name."""
    order = sorted(range(len(file_names)), key=file_names.__getitem__)
    return [file_names[i] for i in order], log_probs[order]


def combine_predictions(cnn_file, svm_file, output_file, *, cnn_weight=0.32,
                        svm_weight=0.85, report_accuracy=False):
    """Fuse CNN and SVM predictions and write the hybrid predictions.

    Rows of both files are aligned by file name. Each model's scores are
    z-score normalised per class, and the weighted sum
    ``cnn_weight * cnn_norm + svm_weight * svm_norm`` is reduced with
    ``argmax`` to a 1-based class label per file.

    Args:
        cnn_file: Path to the CNN predictions file.
        svm_file: Path to the SVM predictions file.
        output_file: Path of the output file. One line per file, sorted by
            file name, formatted as ``"<file name> <label>  NaN"``.
        cnn_weight: Weight of the normalised CNN scores. NOTE(review): the
            default 0.32 looks hand-tuned - confirm.
        svm_weight: Weight of the normalised SVM scores (default 0.85).
        report_accuracy: If true, print the accuracy against labels derived
            from the file names by ``get_labels_from_filenames``.

    Raises:
        ValueError: If the CNN file contains no predictions, or the two
            files do not list exactly the same file names.
    """
    cnn_files, _, cnn_log_probs = load_predictions(cnn_file)
    svm_files, _, svm_log_probs = load_predictions(svm_file)

    # The two models may list files in different orders; sort both by name
    # so row i refers to the same file in each.
    cnn_names, cnn_scores = _sort_by_file_name(cnn_files, cnn_log_probs)
    svm_names, svm_scores = _sort_by_file_name(svm_files, svm_log_probs)

    if not cnn_names:
        raise ValueError(f"no predictions found in {cnn_file!r}")
    if cnn_names != svm_names:
        # Combining row-by-row would silently mix scores of different files.
        only_in_one = sorted(set(cnn_names) ^ set(svm_names))
        raise ValueError(
            "CNN and SVM predictions do not cover the same files "
            f"({len(cnn_names)} vs {len(svm_names)} rows; "
            f"names present in only one file: {only_in_one[:5]})"
        )

    cnn_norm = normalize_log_probs(cnn_scores)
    svm_norm = normalize_log_probs(svm_scores)
    combined = cnn_weight * cnn_norm + svm_weight * svm_norm
    # argmax yields 0-based column indices; class labels are 1-based.
    combined_hard_preds = np.argmax(combined, axis=1) + 1

    if report_accuracy:
        true_labels = get_labels_from_filenames(cnn_names)
        accuracy = calculate_accuracy(combined_hard_preds, true_labels)
        print(f"Validation Accuracy: {accuracy:.2f}%")

    with open(output_file, "w") as f:
        # The score column is a NaN placeholder; the exact line format
        # (including the double space) is kept unchanged.
        f.writelines(
            f"{name} {pred}  NaN\n"
            for name, pred in zip(cnn_names, combined_hard_preds)
        )


def main():
    """Run both base models on the data, then fuse their predictions."""
    # Imported here rather than at module level so the fusion helpers above
    # can be imported without the model modules and their dependencies.
    from image_CNN import cnn_predict
    from audio_SVM import svm_predict

    # "Run combined predictions using CNN and SVM."
    parser = argparse.ArgumentParser(
        description="Spuštění kombinovaných predikcí pomocí CNN a SVM.")
    # "Path to the validation data directory."
    parser.add_argument("-v", "--vdata_path", type=str, required=True,
                        help="Cesta k adresáři validačních dat.")
    # "Path to the training data directory."
    parser.add_argument("-t", "--tdata_path", type=str, required=True,
                        help="Cesta k adresáři tréninkových dat.")
    # "File to save the predictions to."
    parser.add_argument("-o", "--output_file", type=str,
                        default="hybrid_classifier_results.csv",
                        help="Soubor pro uložení predikcí.")
    args = parser.parse_args()

    cnn_predict(args.vdata_path, args.tdata_path)
    svm_predict(args.vdata_path, args.tdata_path)
    # NOTE(review): these file names presumably match what cnn_predict and
    # svm_predict write - confirm against those modules.
    combine_predictions("image_CNN_results.csv", "audio_SVM_results.csv",
                        args.output_file)


if __name__ == "__main__":
    main()