# SUR - Project
# May 2025
# Michal Zobaník (xzoban01), Radek Zobaník (xzoban02)
# Main file for running training and classification of the models

import os
# Must be set before the model modules below are imported, otherwise the
# setting comes too late (presumably they import TensorFlow - confirm).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # remove tensorflow logging

import argparse
from result import Result
from audio import Audio_GMM
from image_cnn import Image_cnn

# Number of target classes; valid class labels are the integers 1..NUM_CLASSES.
NUM_CLASSES = 31


def save_result(results: list[Result], save_file_location, num_classes: int = NUM_CLASSES):
    """Validate classification results and write them to a text file.

    One line is written per result, in the form
    ``<file_name> <chosen_class> <prob_1> ... <prob_n>``.

    All results are validated and formatted before the output file is
    opened, so an invalid result never leaves a truncated or partially
    written file behind.

    Args:
        results: Results to store. Each item must provide ``file_name``,
            ``chosen_class`` and ``probabs``.
        save_file_location: Path of the output file (overwritten if present).
        num_classes: Expected number of probabilities per result and the
            highest valid class label.

    Raises:
        SystemExit: With exit code 1 (after printing a message) when a result
            has the wrong number of probabilities or a class label outside
            ``1..num_classes``.
    """
    lines = []
    for result in results:
        if len(result.probabs) != num_classes:
            print(
                f"Wrong number of probabs in result {result.file_name}: "
                f"expected {num_classes}, got {len(result.probabs)}"
            )
            raise SystemExit(1)

        if result.chosen_class < 1 or result.chosen_class > num_classes:
            print(f"Class number is out of bounds: {result.chosen_class}")
            raise SystemExit(1)

        probabs = ' '.join(str(item) for item in result.probabs)
        lines.append(f"{result.file_name} {result.chosen_class} {probabs}\n")

    # Only touch the output file once every result is known to be valid.
    with open(save_file_location, "w", encoding="utf-8") as file:
        file.writelines(lines)


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        prog="SUR-classification",
        description="Program for using images and audio for classification")
    parser.add_argument("-m", "--model", choices=["a", "i", "b"], default="b",
                        help="What models are used individually for classification. "
                             "a - audio, i - images, b - both")
    parser.add_argument("-t", "--train", action="store_true", default=False,
                        help="Train models")
    parser.add_argument("-c", "--classification",
                        help="Folder with data for classification")
    parser.add_argument("--audio_model_location", default="voice_gmm.pkl",
                        help="Location of the audio model or where it will be stored")
    parser.add_argument("--image_model_location", default="image_cnn_model.keras",
                        help="Location of the image model or where it will be stored")
    parser.add_argument("--train_data", default="dataset/train",
                        help="Location of data used for training of the models")
    parser.add_argument("--eval_data", default="dataset/dev",
                        help="Location of data used for evaluation of the models")
    return parser.parse_args()


def main():
    """Run training and/or classification as selected on the command line."""
    args = _parse_args()

    # Set used models: "a" = audio only, "i" = image only, "b" = both
    # (audio first, then image).
    models = []
    if args.model in ("a", "b"):
        models.append(Audio_GMM(model_location=args.audio_model_location))
    if args.model in ("i", "b"):
        models.append(Image_cnn(model_location=args.image_model_location))

    # Train models
    if args.train:
        # Check every target location before training anything, so that a
        # clash on a later model does not waste a full training run of an
        # earlier one.
        for model in models:
            if os.path.exists(model.model_location):
                print(f"File with model {model.model_location} already exists. Delete it or rename it and run the program again.")
                raise SystemExit(1)

        for model in models:
            model.train(args.train_data, args.eval_data)

    # Classification: each model writes its own result file.
    if args.classification:
        for model in models:
            result = model.classification(args.classification)
            save_result(result, model.result_file_name)


if __name__ == '__main__':
    main()