Source code for simulation.utils.machine_learning.data.image_folder

"""A modified image folder class.

We modify the official PyTorch image folder
(https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""

import os
import os.path
from typing import List

IMG_EXTENSIONS = [
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
    ".png",
    ".PNG",
    ".ppm",
    ".PPM",
    ".bmp",
    ".BMP",
    ".tif",
    ".TIF",
    ".tiff",
    ".TIFF",
]


[docs]def is_image_file(filename): """Check if a file is an image. Args: filename: the file name to check """ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
[docs]def find_images(dir: str, max_dataset_size=float("inf")) -> List[str]: """Recursively search for images in the given directory. Args: dir: the directory of the dataset max_dataset_size: the maximum amount of images to load """ images = [] assert os.path.isdir(dir), "%s is not a valid directory" % dir for root, _, fnames in os.walk(dir): for fname in fnames: if is_image_file(fname): path = os.path.join(root, fname) images.append(path) return sorted(images[: min(max_dataset_size, len(images))])