Source code for simulation.utils.machine_learning.models.unet_generator

from torch import Tensor
from torch import nn as nn

from .unet_block import UnetSkipConnectionBlock


[docs]class UnetGenerator(nn.Module): """Create a Unet-based generator.""" def __init__( self, input_nc: int, output_nc: int, num_downs: int, ngf: int = 64, norm_layer: nn.Module = nn.BatchNorm2d, use_dropout: bool = False, ): """Construct a Unet generator. We construct the U-Net from the innermost layer to the outermost layer. It is a recursive process. Args: input_nc (int): the number of channels in input images output_nc (int): the number of channels in output images num_downs (int): the number of downsampling layers in UNet. For example, # if |num_downs| == 7, image of size 128x128 will become of size 1x1 # at the bottleneck ngf (int): the number of filters in the last conv layer norm_layer (nn.Module): normalization layer use_dropout (bool): Use dropout or not """ super().__init__() # construct unet structure unet_block = UnetSkipConnectionBlock( ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, ) # add the innermost layer for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters unet_block = UnetSkipConnectionBlock( ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout, ) # gradually reduce the number of filters from ngf * 8 to ngf unet_block = UnetSkipConnectionBlock( ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer ) unet_block = UnetSkipConnectionBlock( ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer ) unet_block = UnetSkipConnectionBlock( ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer ) self.model = UnetSkipConnectionBlock( output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer, ) # add the outermost layer
[docs] def forward(self, input: Tensor) -> Tensor: """Standard forward. Args: input (Tensor): the input tensor """ return self.model(input)