simulation.utils.machine_learning.models.unet_block module

Summary

Classes:

UnetSkipConnectionBlock

Defines the Unet submodule with skip connection.

Reference

class UnetSkipConnectionBlock(outer_nc: int, inner_nc: int, input_nc: Optional[int] = None, submodule: Optional[torch.nn.modules.module.Module] = None, outermost: bool = False, innermost: bool = False, norm_layer: torch.nn.modules.module.Module = torch.nn.modules.batchnorm.BatchNorm2d, use_dropout: bool = False)

Bases: torch.nn.modules.module.Module

Defines the Unet submodule with skip connection.

X -------------------identity----------------------
|-- downsampling -- |submodule| -- upsampling --|
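The following is a minimal sketch of how a block with this signature can be assembled in PyTorch. It assumes the layer layout of the pix2pix/CycleGAN reference U-Net: 4x4 convolutions with stride 2, LeakyReLU on the way down, ReLU and transposed convolutions on the way up, Tanh at the outermost level, and channel-wise concatenation as the skip connection. The class name `SketchUnetBlock` is hypothetical, and this is not the module's own source.

```python
from typing import Optional, Type

import torch
from torch import nn


class SketchUnetBlock(nn.Module):
    """Hypothetical U-Net skip block; layer layout assumed from the pix2pix reference."""

    def __init__(
        self,
        outer_nc: int,
        inner_nc: int,
        input_nc: Optional[int] = None,
        submodule: Optional[nn.Module] = None,
        outermost: bool = False,
        innermost: bool = False,
        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
        use_dropout: bool = False,
    ) -> None:
        super().__init__()
        self.outermost = outermost
        if input_nc is None:
            input_nc = outer_nc
        # Assumption: conv biases are dropped when BatchNorm follows (its shift makes them redundant).
        use_bias = norm_layer is not nn.BatchNorm2d
        # Downsampling: a 4x4 convolution with stride 2 halves height and width.
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        # Not in-place, so the input tensor reused by the skip connection stays untouched.
        downrelu = nn.LeakyReLU(0.2)
        uprelu = nn.ReLU()

        if outermost:
            # The submodule returns inner_nc * 2 channels (its input plus its output).
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # Bottleneck: no submodule, go straight back up.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            layers = [downrelu, downconv, uprelu, upconv, norm_layer(outer_nc)]
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            layers = [downrelu, downconv, norm_layer(inner_nc), submodule, uprelu, upconv, norm_layer(outer_nc)]
            if use_dropout:
                layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.outermost:
            return self.model(x)
        # Skip connection: concatenate input and output along the channel dimension.
        return torch.cat([x, self.model(x)], dim=1)


# Nest blocks from the inside out: 3 -> 8 -> 16 -> 32 channels and back.
innermost = SketchUnetBlock(16, 32, innermost=True)
middle = SketchUnetBlock(8, 16, submodule=innermost, use_dropout=True)
outer = SketchUnetBlock(3, 8, input_nc=3, submodule=middle, outermost=True)
print(outer(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 3, 32, 32])
```

Nesting mirrors the diagram above: each block wraps the next one as its `submodule`, so a full U-Net is built from the innermost block outwards, and only the outermost block omits the identity branch.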

forward(x: torch.Tensor) → torch.Tensor

Run the forward pass, applying the skip connection unless this is the outermost block.

Parameters

x (torch.Tensor) – the input tensor
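The skip-connection rule of forward can be illustrated in isolation. This sketch assumes, as in the pix2pix reference U-Net, that the skip connection concatenates the block's input with its output along the channel dimension; `skip_forward` is a hypothetical helper, and a single `nn.Conv2d` stands in for the downsampling/submodule/upsampling path.

```python
import torch
from torch import nn


def skip_forward(model: nn.Module, x: torch.Tensor, outermost: bool) -> torch.Tensor:
    """Hypothetical helper mirroring the documented forward rule (concatenation assumed)."""
    if outermost:
        # The outermost block returns the network output directly.
        return model(x)
    # Every other block stacks its input and its output along the channel dimension.
    return torch.cat([x, model(x)], dim=1)


# Stand-in for the down -> submodule -> up path: keeps 4 channels and the spatial size.
path = nn.Conv2d(4, 4, kernel_size=3, padding=1)
x = torch.randn(2, 4, 16, 16)

inner_out = skip_forward(path, x, outermost=False)
outer_out = skip_forward(path, x, outermost=True)
print(inner_out.shape)  # torch.Size([2, 8, 16, 16])
print(outer_out.shape)  # torch.Size([2, 4, 16, 16])
```

Because non-outermost blocks double the channel count this way, the enclosing block's upsampling layer has to expect `inner_nc * 2` input channels.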

training: bool