Source code for simulation.utils.machine_learning.data

import os
from typing import Any, Dict, List, Tuple, Union

from .data_loader import DataLoader
from .labeled_dataset import LabeledDataset
from .unlabeled_dataset import UnlabeledDataset


def load_unpaired_unlabeled_datasets(
    dir_a: Union[str, List[str]],
    dir_b: Union[str, List[str]],
    max_dataset_size: int,
    batch_size: int,
    sequential: bool,
    num_threads: int,
    grayscale_a: bool,
    grayscale_b: bool,
    transform_properties: Dict[str, Any],
) -> Tuple[DataLoader, DataLoader]:
    """Create dataloaders for two unpaired and unlabeled datasets.

    E.g. used by cycle gan with data from two domains.

    Args:
        dir_a: path to images of domain a
        dir_b: path to images of domain b
        max_dataset_size (int): maximum amount of images to load; -1 means infinity
        batch_size (int): input batch size
        sequential (bool): if true, takes images in order, otherwise takes them randomly
        num_threads (int): threads for loading data
        grayscale_a (bool): transform domain a to gray images
        grayscale_b (bool): transform domain b to gray images
        transform_properties: dict containing properties for transforming images;
            it is not modified, each dataset receives its own copy with the
            matching "grayscale" flag set

    Returns:
        Tuple of the dataloader for domain a and the dataloader for domain b.
    """
    max_dataset_size = float("inf") if max_dataset_size == -1 else max_dataset_size

    def _to_dataloader(dataset: UnlabeledDataset) -> DataLoader:
        # Both domains share every loading option except the dataset itself.
        return DataLoader(
            dataset=dataset,
            max_dataset_size=max_dataset_size,
            batch_size=batch_size,
            num_threads=num_threads,
            sequential=sequential,
        )

    # Build a separate properties dict per domain instead of writing the
    # grayscale flag into the shared argument. Writing into the shared dict
    # would mutate the caller's object and, if UnlabeledDataset keeps a
    # reference to the dict it receives, would also overwrite domain a's
    # flag with domain b's.
    properties_a = {**transform_properties, "grayscale": grayscale_a}
    properties_b = {**transform_properties, "grayscale": grayscale_b}

    a = UnlabeledDataset(properties_a, dir_a)
    b = UnlabeledDataset(properties_b, dir_b)

    # Transform datasets into dataloaders.
    return _to_dataloader(a), _to_dataloader(b)
def sample_generator(
    dataloader: "DataLoader",
    n_samples: Union[int, float] = float("inf"),
):
    """Generator that samples from a dataloader.

    When the dataloader is exhausted, it is restarted with a fresh ``iter()``
    call, so more batches than one pass provides can be drawn.

    Args:
        dataloader: Dataloader.
        n_samples: Number of batches of samples; infinite by default.

    Yields:
        One batch per step, at most ``n_samples`` batches. If the dataloader
        yields nothing even right after being restarted, the generator stops
        early.
    """
    iterator = iter(dataloader)
    produced = 0
    while produced < n_samples:
        try:
            batch = next(iterator)
        except StopIteration:
            # End of one pass: start a new pass over the dataloader.
            iterator = iter(dataloader)
            try:
                batch = next(iterator)
            except StopIteration:
                # Empty even after restarting. Stop cleanly; letting the
                # StopIteration escape would be turned into a RuntimeError
                # (PEP 479).
                return
        produced += 1
        yield batch
def unpaired_sample_generator(
    dataloader_a: "DataLoader",
    dataloader_b: "DataLoader",
    n_samples: Union[int, float] = float("inf"),
):
    """Generator that samples pairwise from both dataloaders.

    Each dataloader is restarted independently with a fresh ``iter()`` call
    whenever it is exhausted, so loaders of different lengths can be paired.

    Args:
        dataloader_a: Domain a dataloader.
        dataloader_b: Domain b dataloader.
        n_samples: Number of batches of samples; infinite by default.

    Yields:
        ``(batch_a, batch_b)`` tuples, at most ``n_samples`` of them. If either
        dataloader yields nothing even right after being restarted, the
        generator stops early.
    """

    def _draw(iterator, dataloader):
        # Return (iterator, batch), restarting the dataloader once if the
        # current pass is exhausted. A second StopIteration propagates and
        # means the dataloader is empty.
        try:
            return iterator, next(iterator)
        except StopIteration:
            iterator = iter(dataloader)
            return iterator, next(iterator)

    iter_a = iter(dataloader_a)
    iter_b = iter(dataloader_b)
    produced = 0
    while produced < n_samples:
        try:
            iter_a, next_a = _draw(iter_a, dataloader_a)
            iter_b, next_b = _draw(iter_b, dataloader_b)
        except StopIteration:
            # One of the dataloaders is empty. Stop cleanly; letting the
            # StopIteration escape would be turned into a RuntimeError
            # (PEP 479).
            return
        produced += 1
        yield next_a, next_b
def load_labeled_dataset(
    label_file: str,
    max_dataset_size: int,
    batch_size: int,
    sequential: bool,
    num_threads: int,
    transform_properties: Dict[str, Any],
) -> DataLoader:
    """Create dataloader for a labeled dataset.

    Args:
        label_file: Path to a file containing all labels; it is loaded with
            ``LabeledDataset.from_yaml`` and its directory becomes the
            dataset's base path
        max_dataset_size: Maximum amount of images to load; -1 means infinity
        batch_size: Batch size
        sequential: If true, takes images in order, otherwise takes them randomly
        num_threads: Threads for loading data
        transform_properties: Properties for transforming images; assigned to
            the dataset as-is (the same dict object, not a copy)

    Returns:
        Dataloader wrapping the labeled dataset.
    """
    max_dataset_size = float("inf") if max_dataset_size == -1 else max_dataset_size

    dataset = LabeledDataset.from_yaml(label_file)
    dataset.transform_properties = transform_properties
    # NOTE(review): presumably paths stored in the label file are resolved
    # relative to this base path; confirm against LabeledDataset.
    dataset._base_path = os.path.dirname(label_file)

    # Transform the dataset into a dataloader.
    return DataLoader(
        dataset=dataset,
        max_dataset_size=max_dataset_size,
        batch_size=batch_size,
        num_threads=num_threads,
        sequential=sequential,
    )