Source code for simulation.utils.machine_learning.cycle_gan.models.generator

from typing import List

from torch import nn

from simulation.utils.machine_learning.models.helper import get_norm_layer
from simulation.utils.machine_learning.models.resnet_generator import ResnetGenerator
from simulation.utils.machine_learning.models.unet_generator import UnetGenerator


def create_generator(
    input_nc: int,
    output_nc: int,
    ngf: int,
    netg: str,
    norm: str = "batch",
    use_dropout: bool = False,
    activation: nn.Module = nn.Tanh(),
    conv_layers_in_block: int = 2,
    dilations: List[int] = None,
) -> nn.Module:
    """Create a generator.

    Our current implementation provides two types of generators.

    U-Net: [unet_128] (for 128x128 input images) and [unet_256]
    (for 256x256 input images).
    The original U-Net paper: https://arxiv.org/abs/1505.04597

    Resnet-based generator: [resnet_<N>blocks], e.g. [resnet_6blocks] (with 6
    Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks).
    Resnet-based generator consists of several Resnet blocks between a few
    downsampling/upsampling operations. We adapt Torch code from Justin
    Johnson's neural style transfer project
    (https://github.com/jcjohnson/fast-neural-style).
    It uses RELU for non-linearity.

    The architecture name is validated before anything is built, so an
    unsupported or malformed name fails without constructing any layers.

    Note:
        The default ``activation`` is a single ``nn.Tanh()`` instance created
        when this function is defined; every call that does not pass
        ``activation`` receives that same object.

    Args:
        input_nc (int): # of input image channels: 3 for RGB and 1 for grayscale
        output_nc (int): # of output image channels: 3 for RGB and 1 for
            grayscale
        ngf (int): # of gen filters in the last conv layer
        netg (str): specify generator architecture
            [resnet_<ANY_INTEGER>blocks | unet_256 | unet_128]
        norm (str): instance normalization or batch normalization
            [instance | batch | none]
        use_dropout (bool): enable or disable dropout
        activation (nn.Module): Choose which activation to use.
            Only forwarded to the resnet generator.
        conv_layers_in_block (int): specify number of convolution layers per
            resnet block. Only forwarded to the resnet generator.
        dilations: dilation for individual conv layers in every resnet block.
            Only forwarded to the resnet generator.

    Returns:
        The constructed generator network.

    Raises:
        NotImplementedError: If ``netg`` is not one of the supported names,
            including resnet-like names that do not have the exact form
            ``resnet_<ANY_INTEGER>blocks``.
    """
    resnet_prefix, resnet_suffix = "resnet_", "blocks"
    # Third positional argument handed to UnetGenerator for each supported
    # name (presumably the number of downsampling steps -- confirm against
    # UnetGenerator).
    unet_depths = {"unet_128": 7, "unet_256": 8}
    unrecognized = "Generator model name [%s] is not recognized" % netg

    is_resnet = "resnet" in netg
    n_blocks = None
    if is_resnet:
        # Require the exact resnet_<N>blocks shape. Slicing by fixed offsets
        # alone would silently turn e.g. "resnet_123456789" into 123 blocks.
        well_formed = netg.startswith(resnet_prefix) and netg.endswith(resnet_suffix)
        count_text = (
            netg[len(resnet_prefix) : -len(resnet_suffix)] if well_formed else ""
        )
        try:
            n_blocks = int(count_text)
        except ValueError:
            # Report malformed names the same way as unknown ones.
            raise NotImplementedError(unrecognized) from None
    elif netg not in unet_depths:
        raise NotImplementedError(unrecognized)

    norm_layer = get_norm_layer(norm_type=norm)

    if is_resnet:
        return ResnetGenerator(
            input_nc,
            output_nc,
            ngf,
            norm_layer=norm_layer,
            use_dropout=use_dropout,
            n_blocks=n_blocks,
            activation=activation,
            conv_layers_in_block=conv_layers_in_block,
            dilations=dilations,
        )

    return UnetGenerator(
        input_nc,
        output_nc,
        unet_depths[netg],
        ngf,
        norm_layer=norm_layer,
        use_dropout=use_dropout,
    )