"""Unlabeled image dataset (``simulation.utils.machine_learning.data.unlabeled_dataset``)."""

from dataclasses import dataclass, field
from typing import List, Tuple, Union

from PIL import Image

from .base_dataset import BaseDataset
from .image_folder import find_images


@dataclass
class UnlabeledDataset(BaseDataset):
    """This dataset class can load a set of unlabeled data.

    Image files are discovered with ``find_images`` in one or more folders
    when the dataset is constructed. Each item is returned together with the
    path of the file it was loaded from.
    """

    folder_path: Union[str, List[str]] = field(default_factory=list)
    """Path[s] to folders that contain the data."""

    def __post_init__(self):
        # Resolve all image paths once, at construction time.
        self.file_paths = self.load_file_paths()

    def load_file_paths(self) -> List[str]:
        """Collect the paths of all images in ``folder_path``.

        ``folder_path`` may be a single folder or a list/tuple of folders;
        for several folders the results are concatenated in folder order.

        Returns:
            List[str]: File paths to all data.
        """
        if isinstance(self.folder_path, (list, tuple)):
            # Flatten in a single pass; ``sum(..., [])`` would re-copy the
            # accumulated list for every folder (quadratic).
            return [
                path for folder in self.folder_path for path in find_images(folder)
            ]
        return find_images(self.folder_path)

    def __getitem__(self, index) -> Tuple[Image.Image, str]:
        """Return a data point and its metadata information.

        Args:
            index: Index of returned datapoint. Indices outside the dataset
                wrap around (taken modulo the number of files).

        Returns:
            Datapoint + path to its file

        Raises:
            IndexError: If the dataset contains no files.
        """
        n_files = len(self)
        if n_files == 0:
            # Guard the modulo below, which would otherwise fail with an
            # uninformative ZeroDivisionError.
            raise IndexError(
                f"{type(self).__name__} is empty: no images found in "
                f"{self.folder_path!r}"
            )
        path = self.file_paths[index % n_files]  # make sure index is within the range
        # The context manager releases the file handle; Pillow does not close
        # it automatically after loading for every format (e.g. multi-frame).
        # ``convert`` returns a new, fully loaded image, so it outlives the block.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        # apply image transformation
        img = self.transform(img)
        return img, path

    def __len__(self) -> int:
        """Return the total number of images in the dataset."""
        return len(self.file_paths)