#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Image processing and augmentation utilities.
"""

__author__ = "Anna Ovesná"
__email__ = "xovesn03@stud.fit.vutbr.cz"

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms


class ImageDataset(Dataset):
    """Image-folder dataset with optional augmentation and normalization.

    Wraps ``torchvision.datasets.ImageFolder``. If normalization statistics
    are not supplied, per-channel mean/std are computed over the whole
    folder (images resized to ``img_size``) and cached to ``mean.npy`` /
    ``std.npy`` in the working directory for reuse.
    """

    def __init__(self, root_dir, img_size=80, augmentation_enabled=False,
                 mean=None, std=None):
        """Build the dataset.

        Args:
            root_dir: Path to an ImageFolder-style directory (one
                subdirectory per class).
            img_size: Side length every image is resized/cropped to.
            augmentation_enabled: If True, apply random training-time
                augmentations; otherwise use a deterministic eval pipeline.
            mean: Per-channel means for normalization. Computed from the
                data (and saved to ``mean.npy``) when None.
            std: Per-channel stds for normalization. Computed from the
                data (and saved to ``std.npy``) when None.

        Raises:
            ValueError: If statistics must be computed but ``root_dir``
                contains no images.
        """
        if mean is None or std is None:
            mean, std = self._compute_stats(root_dir, img_size)
            # Cache statistics so later runs can pass them in directly.
            np.save("mean.npy", mean.numpy())
            np.save("std.npy", std.numpy())

        if augmentation_enabled:
            transform = transforms.Compose([
                # No explicit Resize needed: RandomResizedCrop below
                # already outputs (img_size, img_size).
                transforms.RandomHorizontalFlip(0.3),
                # transforms.RandomVerticalFlip(0.3),  # doesn't work for this dataset
                transforms.RandomRotation(10),
                transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
                transforms.ColorJitter(brightness=0.2, contrast=0.2,
                                       saturation=0.2, hue=0.1),
                # transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 0.5)),  # getting more consistent higher accuracy without this
                transforms.ToTensor(),
                # transforms.RandomErasing(p=0.3, scale=(0.02, 0.33), ratio=(0.3, 3.3)),  # getting more consistent higher accuracy without this
                transforms.Normalize(mean=mean, std=std),
            ])
        else:
            transform = transforms.Compose([
                # Resize is required here: without it, variable-size images
                # cannot be batched, and the normalization stats (computed
                # on resized images) would not match the inputs.
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ])

        self.dataset = datasets.ImageFolder(root_dir, transform=transform)

    @staticmethod
    def _compute_stats(root_dir, img_size):
        """Compute per-channel mean/std over all images in ``root_dir``.

        NOTE: std is the average of per-image stds, not the true pooled
        dataset std — a common, slightly biased approximation kept for
        compatibility with previously saved statistics.

        Returns:
            Tuple ``(mean, std)`` of shape-(3,) float tensors.
        """
        transform_for_stats = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])
        temp_dataset = datasets.ImageFolder(root_dir,
                                            transform=transform_for_stats)

        mean = torch.zeros(3)
        std = torch.zeros(3)
        total_images = 0
        for image, _ in temp_dataset:
            # Reduce over spatial dims (H, W), keeping the channel dim.
            mean += image.mean(dim=[1, 2])
            std += image.std(dim=[1, 2])
            total_images += 1

        if total_images == 0:
            raise ValueError(
                "Cannot compute normalization statistics: no images found")
        mean /= total_images
        std /= total_images
        return mean, std

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(image_tensor, class_index)`` pair at ``idx``."""
        return self.dataset[idx]