"""ROS connector that translates simulated images to "real" ones with a trained CycleGAN generator.

Module: simulation.utils.machine_learning.cycle_gan.ros_connector
"""

import os
import pathlib

import cv2
import numpy as np
import torch
from PIL import Image

from simulation.utils.machine_learning.data.base_dataset import get_transform
from simulation.utils.machine_learning.data.image_operations import tensor2im
from simulation.utils.machine_learning.models import resnet_generator
from simulation.utils.machine_learning.models.helper import get_norm_layer, init_net

from .configs.test_options import CycleGANTestOptions, WassersteinCycleGANTestOptions
from .models import generator
from .models.cycle_gan_model import CycleGANModel
from .models.wcycle_gan import WassersteinCycleGANModel


class RosConnector:
    """Implementation of a simple ROS interface to translate simulated to "real" images."""

    def __init__(self, use_wasserstein: bool = True):
        """Initialize the RosConnector class.

        Load the default test options (these could also be supplied via the
        command line), build the B -> A generator, load its checkpoint weights
        and switch the model into evaluation mode.

        Args:
            use_wasserstein: Use the Wasserstein CycleGAN test options instead of
                the plain CycleGAN test options.
        """
        opt = WassersteinCycleGANTestOptions if use_wasserstein else CycleGANTestOptions

        # Resolve data paths relative to this module, not the working directory.
        module_dir = pathlib.Path(__file__).parent.absolute()
        # NOTE: ``opt`` is the shared options object itself, not a copy, so this
        # assignment is visible to every other user of it. Repeating it is
        # harmless: os.path.join discards earlier parts when a later part is
        # already absolute. Presumably ``opt.to_dict()`` below forwards the
        # resolved path to the model — confirm before removing the mutation.
        opt.checkpoints_dir = os.path.join(module_dir, opt.checkpoints_dir)

        self.transform = get_transform(
            load_size=opt.load_size,
            crop_size=opt.crop_size,
            preprocess=opt.preprocess,
            mask=os.path.join(module_dir, opt.mask),
            no_flip=True,
            grayscale=True,
        )

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        netg_b_to_a = self._build_generator_b_to_a(opt, self.device)

        model_class = WassersteinCycleGANModel if opt.is_wgan else CycleGANModel
        # Only the B -> A generator is used by __call__, so A -> B is left out.
        self.model = model_class.from_dict(
            netg_a_to_b=None, netg_b_to_a=netg_b_to_a, **opt.to_dict()
        )
        self.model.networks.load(
            os.path.join(opt.checkpoints_dir, opt.name, f"{opt.epoch}_net_"),
            device=self.device,
        )
        self.model.eval()

    @staticmethod
    def _build_generator_b_to_a(opt, device: torch.device):
        """Create and initialize the generator that maps domain B to domain A.

        Args:
            opt: Test options describing the generator architecture.
            device: Device handed to ``init_net``.

        Returns:
            The network returned by ``init_net``.
        """
        if opt.is_wgan:
            net = resnet_generator.ResnetGenerator(
                opt.input_nc,
                opt.output_nc,
                opt.ngf,
                get_norm_layer(opt.norm),
                dilations=opt.dilations,
                conv_layers_in_block=opt.conv_layers_in_block,
            )
        else:
            net = generator.create_generator(
                opt.input_nc,
                opt.output_nc,
                opt.ngf,
                opt.netg,
                opt.norm,
                not opt.no_dropout,
                opt.activation,
                opt.conv_layers_in_block,
                opt.dilations,
            )
        return init_net(net, opt.init_type, opt.init_gain, device)

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Translate an image to a "fake real" image by using the loaded model.

        Args:
            image: Image to be translated to "fake real". Only its first two
                dimensions (height, width) are used for the output size, so a
                trailing channel axis does not break the size bookkeeping; the
                transform is built with ``grayscale=True``.

        Returns:
            Translated image (mono8), resized to the input's height and width.
        """
        # Remember the input size so the result can be scaled back to it.
        h, w = image.shape[:2]

        # The transform pipeline takes a PIL image and returns a tensor.
        tensor: torch.Tensor = self.transform(Image.fromarray(image))
        # Add a batch dimension and move to the model's device.
        tensor = tensor.unsqueeze(0).to(self.device)

        # Pure inference: don't build an autograd graph for every frame.
        with torch.no_grad():
            # Call the module (not .forward) so nn.Module hooks are honored.
            result = self.model.networks.g_b_to_a(tensor)

        # From [-1, 1] to the [0, 255] pixel range.
        result = tensor2im(result, to_rgb=False)

        # Resize to the size the input image has (cv2 expects (width, height)).
        result = cv2.resize(result, dsize=(w, h), interpolation=cv2.INTER_LINEAR)

        # Return as mono8 encoding
        return result