Source code for simulation.utils.machine_learning.models.resnet_block

from typing import List, Optional, Type

from torch import Tensor, nn


class ResnetBlock(nn.Module):
    """Define a Resnet block.

    The block stacks ``padding -> Conv2d(3x3) -> norm -> ReLU`` groups
    (one group per dilation value), optionally separated by dropout, and adds
    the block input to the result in :meth:`forward`.
    """

    def __init__(
        self,
        dim: int,
        padding_type: str,
        norm_layer: Type[nn.Module],
        use_dropout: bool,
        use_bias: bool,
        n_conv_layers: int = 2,
        dilations: Optional[List[int]] = None,
    ):
        """Initialize the Resnet block.

        A resnet block is a conv block with skip connections.
        We implement skip connections in <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf

        Args:
            dim: Number of channels in the conv layer
            padding_type: Name of padding layer: reflect | replicate | zero
            norm_layer: Normalization Layer class
            use_dropout: Whether to use dropout layers
            use_bias: Whether the conv layer uses bias or not
            n_conv_layers: Number of convolution layers in this block
            dilations: List of dilations for each conv layer. Defaults to a
                dilation of 1 for every conv layer.

        Raises:
            AssertionError: If ``dilations`` does not hold exactly
                ``n_conv_layers`` values.
            NotImplementedError: If ``padding_type`` is not one of the
                supported names.
        """
        super().__init__()

        if dilations is None:
            dilations = [1] * n_conv_layers
        # Raised explicitly (instead of using an ``assert`` statement) so the
        # check is not stripped when Python runs with ``-O``; the exception
        # type is kept for backward compatibility.
        if n_conv_layers != len(dilations):
            raise AssertionError(
                "There must be exactly one dilation value for each conv layer."
            )

        layers: List[nn.Module] = []
        last_index = len(dilations) - 1
        for index, dilation in enumerate(dilations):
            # A 3x3 kernel with dilation d needs d pixels of padding on every
            # side to keep the spatial size, which the skip connection needs.
            conv_padding = 0
            if padding_type == "reflect":
                layers.append(nn.ReflectionPad2d(dilation))
            elif padding_type == "replicate":
                layers.append(nn.ReplicationPad2d(dilation))
            elif padding_type == "zero":
                # Zero padding is handled by the convolution itself.
                conv_padding = dilation
            else:
                raise NotImplementedError(
                    "padding [%s] is not implemented" % padding_type
                )

            layers += [
                nn.Conv2d(
                    dim,
                    dim,
                    kernel_size=3,
                    padding=conv_padding,
                    dilation=dilation,
                    bias=use_bias,
                ),
                norm_layer(dim),
                nn.ReLU(True),
            ]

            # Dropout only goes *between* conv groups, never after the last
            # one. Inserting it conditionally (rather than appending and then
            # deleting the trailing entry) also works when there are no conv
            # layers at all.
            if use_dropout and index < last_index:
                layers.append(nn.Dropout(0.5))

        self.conv_block = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Standard forward with skip connection.

        Args:
            x: Input tensor

        Returns:
            The sum of ``x`` and the output of the conv block applied to ``x``.
        """
        return x + self.conv_block(x)  # add skip connections