"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
It also includes common transformation functions (e.g., get_transform, __scale_width), which
can be later used in subclasses.
"""
import random
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Tuple
import numpy as np
import PIL.ImageOps
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image


@dataclass
class BaseDataset(data.Dataset):
    """This is the base class for other datasets."""

    transform_properties: Dict[str, Any] = field(default_factory=dict)
    """Properties passed as arguments to the transform generation function."""

    def __len__(self):
        """Return the total number of images in the dataset."""
        # Placeholder: subclasses should override this with the real size.
        return -1

    def __getitem__(self, index):
        """Return the item at the given index.

        Args:
            index: the index of the item to get
        """
        raise NotImplementedError()

    @property
    def transform(self) -> transforms.Compose:
        """transforms.Compose: Transformation that can be applied to images."""
        return get_transform(**self.transform_properties)
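
# A minimal sketch of a concrete subclass (illustrative only: the class name,
# the ``image_paths`` field, and the loading logic are assumptions, not part
# of this module):
#
#     @dataclass
#     class FileListDataset(BaseDataset):
#         image_paths: list = field(default_factory=list)
#
#         def __len__(self):
#             return len(self.image_paths)
#
#         def __getitem__(self, index):
#             img = Image.open(self.image_paths[index]).convert("RGB")
#             return self.transform(img)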


def get_params(
    preprocess: Iterable, load_size: int, crop_size: int, size: Tuple[int, int]
) -> Dict[str, Any]:
    """Compute a random crop position for the given preprocessing settings.

    Args:
        preprocess: Scaling and cropping of images at load time
            [resize | crop | scale_width]
        load_size: Scale images to this size.
        crop_size: Then crop to this size.
        size: The original image size as (width, height).

    Returns:
        A dict with key "crop_pos" holding the (x, y) top-left crop position.
    """
    w, h = size
    new_h = h
    new_w = w
    if "resize" in preprocess and "crop" in preprocess:
        new_h = new_w = load_size
    elif "scale_width" in preprocess and "crop" in preprocess:
        new_w = load_size
        new_h = load_size * h // w
    x = random.randint(0, np.maximum(0, new_w - crop_size))
    y = random.randint(0, np.maximum(0, new_h - crop_size))
    return {"crop_pos": (x, y)}


def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize the image so that both dimensions are multiples of ``base``.

    Args:
        img: image to transform
        base: round each dimension to the nearest multiple of this value
        method: the PIL resampling method used for resizing
    """
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if h == oh and w == ow:
        return img
    __print_size_warning(ow, oh, w, h)
    return img.resize((w, h), method)
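
# For example, a 500x367 image with base=4 keeps its width
# (round(500 / 4) * 4 == 500) but has its height rounded to 368
# (round(367 / 4) * 4 == 368), triggering the one-time size warning below.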


def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
    """Scale the image to ``target_size`` width, preserving the aspect ratio.

    The height is clamped to at least ``crop_size`` so that a later crop of
    that size always fits.

    Args:
        img: image to transform
        target_size: the target width (the load size)
        crop_size: the crop size, which is used for training
        method: the PIL resampling method used for resizing
    """
    ow, oh = img.size
    if ow == target_size and oh >= crop_size:
        return img
    w = target_size
    h = int(max(target_size * oh / ow, crop_size))
    return img.resize((w, h), method)
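
# For example, with target_size=512 and crop_size=256, a 1024x512 image is
# scaled to 512x256 (aspect ratio preserved), while a flatter 1024x256 image
# is also scaled to 512x256, because the height is clamped to crop_size
# (512 * 256 / 1024 == 128 would be too small to crop from).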


def __crop(img, pos, size):
    """Crop a ``size`` x ``size`` square out of the image.

    Args:
        img: image to transform
        pos: (x, y) position of the top-left corner of the crop
        size: side length of the resulting square crop
    """
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if ow > tw or oh > th:
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img
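
# For example, cropping a 286x286 image at pos=(15, 20) with size=256 yields
# the 256x256 region with box (15, 20, 271, 276); an image that is already no
# larger than 256x256 in both dimensions is returned unchanged.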


def __apply_mask(img: Image.Image, mask_file: str) -> Image.Image:
    """Overlay the image with the provided mask.

    Args:
        img (Image.Image): image to transform
        mask_file (str): path to mask image file
    """
    mask = Image.open(mask_file)
    # Use the inverted mask as the paste transparency: white parts of the
    # mask are see-through, so the underlying image stays visible there.
    img.paste(mask, (0, 0), PIL.ImageOps.invert(mask))
    return img
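
# Usage sketch (the file names are placeholders, and the mask is assumed to
# be a grayscale image of the same size as the photo): white mask regions
# leave the photo visible, black regions are painted over with the mask.
#
#     img = Image.open("photo.png").convert("RGB")
#     img = __apply_mask(img, "mask.png")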


def __flip(img, flip):
    """Flip the image horizontally if ``flip`` is set.

    Args:
        img: image to transform
        flip: whether to apply a left-right flip
    """
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __print_size_warning(ow, oh, w, h):
    """Print warning information about the image size (printed only once).

    Args:
        ow: original width
        oh: original height
        w: adjusted width
        h: adjusted height
    """
    # The flag is stored on the function object itself, so the warning is
    # emitted only once per process.
    if not hasattr(__print_size_warning, "has_printed"):
        print(
            "The image size needs to be a multiple of 4. "
            "The loaded image size was (%d, %d), so it was adjusted to "
            "(%d, %d). This adjustment will be done to all images "
            "whose sizes are not multiples of 4" % (ow, oh, w, h)
        )
        __print_size_warning.has_printed = True