""" This module contains functions for training and validating a PyTorch model. """ import torch.nn as nn import torch from torch.utils.data import DataLoader, random_split from torchvision import datasets, transforms from typing import Optional, Tuple from pathlib import Path from functools import cached_property def calculate_mean_std(dataset_path: str): """Calculates the mean and standard deviation of an image dataset. Returns: tuple: A tuple containing the mean and standard deviation as PyTorch tensors for each channel (RGB). """ dataset = datasets.ImageFolder(dataset_path, transform=transforms.ToTensor()) dataloader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, num_workers=2) mean = torch.zeros(3) std = torch.zeros(3) for images, _ in dataloader: batch_samples = images.size(0) # Number of images in the batch images = images.view(batch_samples, 3, -1) # Flatten each image mean += images.mean(dim=2).sum(dim=0) std += images.std(dim=2).sum(dim=0) mean /= len(dataset) std /= len(dataset) return mean, std class CustomizedDataset: def __init__(self, dataset_path: str): self.dataset_path = Path(dataset_path) assert self.dataset_path.exists(), f"Dataset path {self.dataset_path} does not exist." assert self.dataset_path.is_dir(), f"Dataset path {self.dataset_path} is not a directory." std, mean = calculate_mean_std(self.dataset_path) self._transform = transforms.Compose( [ transforms.ToTensor(), transforms.Normalize(mean, std) ] ) @cached_property def full_dataset(self): return datasets.ImageFolder(self.dataset_path, transform=self._transform) def get_train_val(self) -> Tuple[datasets.ImageFolder, datasets.ImageFolder]: return random_split(self.full_dataset, [0.8, 0.2]) @cached_property def num_classes(self): return len(self.full_dataset.classes) @cached_property def class_to_idx(self): return self.full_dataset.class_to_idx @cached_property def idx_to_class(self): idx_to_class = {v:k for k, v in self.full_dataset.class_to_idx.items()} return idx_to_class def train_model( model: nn.Module, train_loader: DataLoader, optimizer: torch.optim.Optimizer, criterion: nn.Module, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, validation_loader: Optional[DataLoader] = None, device: Optional[torch.device] = None, early_stopping_patience: Optional[int] = None, epochs: int = 10, ) -> dict: """ Train the given model and return training history. """ if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) history = { "train_loss": [], "val_loss": [], "train_acc": [], "val_acc": [], "lr": [], } if early_stopping_patience is not None: best_val_loss = float('inf') epochs_without_improvement = 0 for epoch in range(epochs): model.train() running_loss = 0.0 correct = 0 total = 0 for images, labels in train_loader: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() _, predicted = torch.max(outputs, 1) correct += (predicted == labels).sum().item() total += labels.size(0) avg_train_loss = running_loss / len(train_loader) train_accuracy = 100 * correct / total val_loss = None val_accuracy = None if validation_loader is not None: val_loss, val_accuracy = validate_model(model, validation_loader, criterion, verbose=False) if early_stopping_patience is not None: if val_loss < best_val_loss: best_val_loss = val_loss epochs_without_improvement = 0 else: epochs_without_improvement += 1 if epochs_without_improvement >= early_stopping_patience: print(f"Early stopping triggered after {epoch+1} epochs.") break if scheduler: if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and val_loss is not None: scheduler.step(val_loss) else: scheduler.step() current_lr = optimizer.param_groups[0]['lr'] # Store history history["train_loss"].append(avg_train_loss) history["val_loss"].append(val_loss) history["train_acc"].append(train_accuracy) history["val_acc"].append(val_accuracy) history["lr"].append(current_lr) print( f"Epoch [{epoch+1}/{epochs}] | " f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}% | " f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}% | " f"LR: {current_lr:.6f}" ) return history @torch.no_grad() def validate_model( model: nn.Module, validation_loader: DataLoader, criterion: Optional[nn.Module] = None, device: Optional[torch.device] = None, verbose: bool = True, ) -> Tuple[float, float]: """ Evaluate the given model on the given validation dataset. Args: model (nn.Module): The model to evaluate. validation_loader (DataLoader): The validation dataset to evaluate on. criterion (Optional[nn.Module]): The loss function to use during evaluation. Defaults to None. device (Optional[torch.device]): The device to move the model and tensor to (e.g., 'cuda' or 'cpu'). If None, will use the GPU if available. Defaults to None. verbose (bool): Whether to print the validation accuracy and loss after evaluation. Defaults to True. Returns: Tuple[float, float]: A tuple containing the average validation loss and accuracy. """ if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) model.eval() correct = 0 total = 0 total_loss = 0.0 for images, labels in validation_loader: images, labels = images.to(device), labels.to(device) outputs = model(images) if criterion: loss = criterion(outputs, labels) total_loss += loss.item() _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = 100 * correct / total avg_loss = total_loss / len(validation_loader) if verbose: print(f"Validation Accuracy: {accuracy:.2f}%") if criterion: print(f"Validation Loss: {avg_loss:.4f}") return avg_loss, accuracy