Source code for simulation.utils.machine_learning.cycle_gan.image_translator

import argparse
import os
import pathlib

import cv2
import numpy as np
import torch
from PIL import Image

from import get_transform
from import tensor2im
from simulation.utils.machine_learning.models import resnet_generator
from simulation.utils.machine_learning.models.helper import get_norm_layer, init_net

from .configs.test_options import CycleGANTestOptions, WassersteinCycleGANTestOptions
from .models import generator
from .models.cycle_gan_model import CycleGANModel
from .models.wcycle_gan import WassersteinCycleGANModel

class ImageTranslator:
    """Implementation of a simple ROS interface to translate simulated to "real" images."""

    def __init__(self, use_wasserstein=True):
        """Initialize the ImageTranslator class.

        Uses the default test options (these could also be supplied via the command
        line), builds the B->A generator, loads its checkpoint and puts the model into
        evaluation mode.

        Args:
            use_wasserstein: If true, use the Wasserstein CycleGAN test options,
                otherwise the default CycleGAN test options.
        """
        opt = WassersteinCycleGANTestOptions if use_wasserstein else CycleGANTestOptions

        # Resolve the checkpoint directory relative to this file so loading does not
        # depend on the current working directory.
        # NOTE(review): this assigns onto the shared options object. Because
        # os.path.join returns its last argument unchanged once it is absolute,
        # constructing several translators does not keep prefixing the path.
        opt.checkpoints_dir = os.path.join(
            pathlib.Path(__file__).parent.absolute(), opt.checkpoints_dir
        )

        tf_properties = {
            "load_size": opt.load_size,
            "crop_size": opt.crop_size,
            "preprocess": opt.preprocess,
            "mask": os.path.join(os.path.dirname(__file__), opt.mask),
            "no_flip": True,
            "grayscale": True,
        }
        self.transform = get_transform(**tf_properties)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Only the B->A generator (simulated -> "real") is needed for translation.
        if opt.is_wgan:
            netg_b_to_a = resnet_generator.ResnetGenerator(
                opt.input_nc,
                opt.output_nc,
                opt.ngf,
                get_norm_layer(opt.norm),
                dilations=opt.dilations,
                conv_layers_in_block=opt.conv_layers_in_block,
            )
        else:
            netg_b_to_a = generator.create_generator(
                opt.input_nc,
                opt.output_nc,
                opt.ngf,
                opt.netg,
                opt.norm,
                not opt.no_dropout,
                opt.activation,
                opt.conv_layers_in_block,
                opt.dilations,
            )

        netg_b_to_a = init_net(netg_b_to_a, opt.init_type, opt.init_gain, self.device)

        model_cls = CycleGANModel if not opt.is_wgan else WassersteinCycleGANModel

        self.model = model_cls.from_dict(
            netg_a_to_b=None, netg_b_to_a=netg_b_to_a, **opt.to_dict()
        )

        # NOTE(review): assumes checkpoints live in a per-experiment sub-directory
        # named after `opt.name` — confirm against the options class.
        self.model.networks.load(
            os.path.join(opt.checkpoints_dir, opt.name, f"{opt.epoch}_net_"),
            device=self.device,
        )
        self.model.eval()

    def __call__(
        self,
        image: np.ndarray,
        f_keep_pixels: float = 0,
        f_keep_colored_pixels: float = 0,
    ) -> np.ndarray:
        """Translate an image to a "fake real" image by using the loaded model.

        Args:
            image: Image to be translated to "fake real". Must be a 3-dimensional
                (h, w, c) array. The blending options convert it with
                ``cv2.COLOR_BGR2GRAY``, so a 3-channel BGR image is expected.
            f_keep_pixels: Factor of original pixels that are kept
            f_keep_colored_pixels: Factor of colored pixels that are kept

        Returns:
            Translated image resized to the input's height and width, as uint8.
        """
        # Store shape (unpacking also enforces a 3-dimensional input)
        h, w, _ = image.shape
        img_np = image

        # Apply transformations and move the tensor onto the generator's device
        tensor: torch.Tensor = self.transform(Image.fromarray(img_np))
        tensor = tensor.to(self.device)
        # Add the batch dimension expected by the generator
        tensor.unsqueeze_(0)

        # Inference; no autograd graph is needed for translation
        with torch.no_grad():
            result = self.model.networks.g_b_to_a.forward(tensor).detach()

        # Convert back to a numpy image (presumably mapping [-1, 1] to [0, 255])
        result = tensor2im(result, to_rgb=False)

        # Resize to the size the input image has
        result = cv2.resize(result, dsize=(w, h), interpolation=cv2.INTER_LINEAR)

        if f_keep_pixels > 0 or f_keep_colored_pixels > 0:
            # Both blending modes mix in the grayscale version of the input
            grey_img = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)

            if f_keep_pixels > 0:
                result = (1 - f_keep_pixels) * result + f_keep_pixels * grey_img

            if f_keep_colored_pixels > 0:
                colored_pxls = np.full((h, w), f_keep_colored_pixels, dtype=float)
                # NOTE(review): a pixel counts as "colored" only if channels 0 and 1
                # differ; pixels whose first two channels match (e.g. pure red in
                # BGR) are treated as uncolored — confirm this is intended.
                colored_pxls[img_np[:, :, 0] == img_np[:, :, 1]] = 0
                result = (1 - colored_pxls) * result + colored_pxls * grey_img

        return result.astype(np.uint8)
if __name__ == "__main__":
    # Run the GAN over all image files in a folder and write the results to another.
    parser = argparse.ArgumentParser(
        description="Translate all images in a folder with the CycleGAN generator."
    )
    parser.add_argument(
        "--input_dir", required=True, help="Directory with input images."
    )
    parser.add_argument(
        "--output_dir", required=True, help="Directory for output images."
    )
    parser.add_argument(
        "--gan_type",
        type=str,
        default="default",
        choices=["default", "wgan"],
        help="Decide whether to use Wasserstein gan or default gan [default, wgan]",
    )
    args = parser.parse_args()

    # ImageTranslator expects a bool. Passing the raw string would always be truthy
    # and silently select the Wasserstein model even for "default".
    translator = ImageTranslator(use_wasserstein=args.gan_type == "wgan")

    image_suffixes = (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif")
    files = [
        file
        for file in sorted(os.listdir(args.input_dir))
        if os.path.isfile(os.path.join(args.input_dir, file))
        and file.lower().endswith(image_suffixes)
    ]

    os.makedirs(args.output_dir, exist_ok=True)

    for i, file in enumerate(files, start=1):
        input_file_path = os.path.join(args.input_dir, file)
        output_file_path = os.path.join(args.output_dir, file)

        # Load as 8-bit grayscale (also handles RGB and palette images) and make
        # sure the file handle is closed again.
        with Image.open(input_file_path) as pil_img:
            img_np = np.array(pil_img.convert("L"))
        # The translator expects a 3-channel BGR image
        img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2BGR)

        translated_image = translator(img_np)
        if not cv2.imwrite(output_file_path, translated_image):
            # NOTE(review): OpenCV cannot write every input format (e.g. GIF).
            print(f"Could not write {output_file_path}")
        print(f"Processing: {100 * i / len(files):.2f}%")