import os

import numpy as np
from PIL import Image
from torch import Tensor

def tensor2im(input_image: Tensor, img_type=np.uint8, to_rgb: bool = True) -> np.ndarray:
    """Convert a Tensor array into a numpy image array.

    Only the first item of the batch (``input_image[0]``) is converted. Its
    channel-first layout is transposed to channel-last and the values are
    mapped with ``(x + 1) / 2 * 255``, i.e. [-1, 1] -> [0, 255]. This presumes
    the tensor holds values in [-1, 1]; TODO confirm with callers.

    Args:
        input_image (Tensor): the input image tensor, indexed as
            (batch, channels, height, width). A numpy array is accepted too and is
            only cast to ``img_type`` (no transposing or rescaling).
        img_type (np.integer): the desired type of the converted numpy array
        to_rgb (bool): translate a single-channel (gray) image to a 3-channel
            RGB image by repeating the channel

    Returns:
        np.ndarray: the image as (height, width, channels), cast to ``img_type``.
    """
    if isinstance(input_image, np.ndarray):
        # Already a numpy image: nothing to convert beyond the dtype.
        return input_image.astype(img_type)

    # detach() so .numpy() also works on tensors that require grad.
    image_numpy = input_image.detach()[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1 and to_rgb:
        # Grayscale to RGB: repeat the single channel three times.
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    # Post-processing: CHW -> HWC, then rescale [-1, 1] to [0, 255].
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(img_type)
def save_image(image_numpy: np.ndarray, image_path: str, aspect_ratio: float = 1.0) -> None:
    """Save a numpy image to the disk.

    Args:
        image_numpy (np.ndarray): input image array laid out as (height, width)
            or (height, width, channels), in a dtype ``PIL.Image.fromarray``
            accepts (e.g. the uint8 output of ``tensor2im``). A trailing
            single channel, (height, width, 1), is saved as a grayscale image.
        image_path (str): the path of the image; the file format is chosen by PIL
            from the extension
        aspect_ratio (float): the aspect ratio of the resulting image. Values
            greater than 1 multiply the width by ``aspect_ratio``; values less
            than 1 multiply the height by ``1 / aspect_ratio``. 1.0 keeps the size.
    """
    if image_numpy.ndim == 3 and image_numpy.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
        image_numpy = image_numpy[:, :, 0]
    image_pil = Image.fromarray(image_numpy)
    h, w = image_numpy.shape[:2]

    # PIL sizes are (width, height), the reverse of numpy's (rows, cols).
    if aspect_ratio > 1.0:
        image_pil = image_pil.resize((int(w * aspect_ratio), h), Image.BICUBIC)
    elif aspect_ratio < 1.0:
        image_pil = image_pil.resize((w, int(h / aspect_ratio)), Image.BICUBIC)
    image_pil.save(image_path)
def save_images(
    visuals: dict,
    destination: str,
    aspect_ratio: float = 1.0,
    post_fix: str = "",
) -> None:
    """Save images to the disk.

    Every entry of ``visuals`` is converted with ``tensor2im`` and written to
    ``<destination>/images/<label>/<post_fix>.png``. The directories are created
    as needed.

    Args:
        visuals (dict): an ordered dictionary that maps a label (used as the
            sub-folder name) to the image data to convert and save
        destination (str): root folder; images are written under its ``images``
            sub-folder
        aspect_ratio (float): the aspect ratio of saved images, forwarded to
            ``save_image``
        post_fix (str): the file name stem (without the ``.png`` extension) used
            for every saved image. NOTE(review): with the default ``""`` the file
            is named ``.png`` (a hidden file on POSIX), so callers presumably
            always pass one.
    """
    image_root = os.path.join(destination, "images")
    # exist_ok avoids the race between an isdir() check and makedirs().
    os.makedirs(image_root, exist_ok=True)
    for label, im_data in visuals.items():
        label_dir = os.path.join(image_root, label)
        os.makedirs(label_dir, exist_ok=True)
        im = tensor2im(im_data)
        save_path = os.path.join(label_dir, f"{post_fix}.png")
        save_image(im, save_path, aspect_ratio=aspect_ratio)