# Source code for the LabeledDataset module (recovered from rendered documentation).

import itertools
import math
import os
import random
from dataclasses import dataclass, field
from itertools import accumulate
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from PIL import Image
from torch import Tensor

from simulation.utils.basics.init_options import InitOptions
from simulation.utils.basics.save_options import SaveOptions

from .base_dataset import BaseDataset
from .image_folder import find_images

@dataclass
class LabeledDataset(BaseDataset, InitOptions, SaveOptions):
    """Dataset of images with labels."""

    attributes: Optional[Sequence[str]] = None
    """Description of what each label means. Similar to headers in a table."""

    classes: Dict[int, str] = field(default_factory=dict)
    """Description of what the class ids represent."""

    labels: Dict[str, List[Sequence[Any]]] = field(default_factory=dict)
    """Collection of all labels structured as a dictionary."""

    _base_path: Optional[str] = None
    """Path to the root of the dataset.

    Only needs to be set if the dataset is used to load data.
    """

    @property
    def available_files(self) -> List[str]:
        """Return basenames of all existing images below ``_base_path``.

        Files with "debug" in their path are excluded.
        """
        return [
            os.path.basename(file)
            for file in find_images(self._base_path)
            if os.path.exists(file) and "debug" not in file
        ]

    def __getitem__(self, index: int) -> Tuple[Union[np.ndarray, Tensor], List[Any]]:
        """Return an image and its label.

        Args:
            index: Index of returned datapoint.
        """
        key = self.available_files[index]
        # [-1] acts as the sentinel label for images without an annotation.
        label = self.labels.get(key, [-1])
        path = os.path.join(self._base_path, key)
        # NOTE(review): this line was corrupted in the extracted source; the
        # surviving "RGB" literal and the PIL import indicate the standard
        # open-and-convert pattern — confirm against the original file.
        img = Image.open(path).convert("RGB")
        # Apply image transformation (self.transform is presumably provided
        # by BaseDataset — verify).
        img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.labels)
[docs] def filter_labels(self): """Remove labels that have no corresponding image.""" all_files = self.available_files self.labels = {key: label for key, label in self.labels.items() if key in all_files}
[docs] def append_label(self, key: str, label: List[Sequence[Any]]): """Add a new label to the dataset. A single image (or any abstract object) can have many labels. """ if key not in self.labels: self.labels[key] = [] self.labels[key].append(label)
[docs] def save_as_yaml(self, file_path: str): """Save the dataset to a yaml file. Override the default method to temporarily remove base_path and prevent writing it to the yaml file. Args: file_path: The output file. """ bp = self._base_path del self._base_path super().save_as_yaml(file_path) self._base_path = bp
[docs] def make_ids_continuous(self): """Reformat dataset to have continuous class ids.""" ids = sorted(self.classes.keys()) for new_id, old_id in enumerate(ids): self.replace_id(old_id, new_id)
[docs] def replace_id(self, search_id: int, replace_id: int): """Replace id (search) with another id (replace) in the whole dataset. Args: search_id: The id being searched for. replace_id: The replacement id that replaces the search ids """ # Replace in classes dict self.classes[replace_id] = self.classes.pop(search_id) # Replace in labels dict index = self.attributes.index("class_id") for label in itertools.chain(*self.labels.values()): if label[index] == search_id: label[index] = replace_id
[docs] def split(self, fractions: List[float], shuffle: bool = True) -> List["LabeledDataset"]: """Split this dataset into multiple.""" assert sum(fractions) == 1.0, "Fractions should sum up to 1" new_datasets = [ LabeledDataset(attributes=self.attributes, classes=self.classes) for _ in fractions ] labels = list(self.labels.items()) if shuffle: random.shuffle(labels) counts = (int(math.ceil(len(self) * frac)) for frac in ([0] + fractions)) indices = list(accumulate(counts)) start_indices = indices[:-1] end_indices = indices[1:] # end_indices[-1] could be bigger than len(self) # But even if it is larger, Python would make sure # that all labels are used exactly once. # So we do not need a test for it. for dataset, from_index, to_index in zip(new_datasets, start_indices, end_indices): dataset.labels = dict(labels[from_index:to_index]) return new_datasets
[docs] @classmethod def from_yaml(cls, file: str) -> "LabeledDataset": """Load a Labeled Dataset from a yaml file. Args: file: The path to the yaml file to load """ instance = cls._from_yaml(cls, file) instance._base_path = os.path.dirname(file) return instance
[docs] @classmethod def split_file( cls, file: str, parts: Dict[str, float], shuffle: bool = True ) -> List["LabeledDataset"]: """Split a dataset file into multiple datasets. Args: file: The path to the yaml file which gets split parts: A dict of names and and fractions shuffle: Split the labels randomly """ # Read dataset and split it dataset = LabeledDataset.from_yaml(file) new_datasets = dataset.split(list(parts.values()), shuffle) # Save the split yaml files for name, dataset in zip(parts.keys(), new_datasets): dataset.save_as_yaml(os.path.join(os.path.dirname(file), f"{name}.yaml")) return new_datasets
[docs] @classmethod def filter_file(cls, file: str) -> "LabeledDataset": """Filter broken file dependencies of a yaml file. Args: file: The path to the yaml file to filter """ labeled_dataset = LabeledDataset.from_yaml(file) labeled_dataset.filter_labels() labeled_dataset.save_as_yaml(file) return labeled_dataset