simulation.utils.machine_learning.models.unet_generator module

Summary

Classes:

UnetGenerator

Create a Unet-based generator.

Reference

class UnetGenerator(input_nc: int, output_nc: int, num_downs: int, ngf: int = 64, norm_layer: torch.nn.modules.module.Module = torch.nn.BatchNorm2d, use_dropout: bool = False)

Bases: torch.nn.modules.module.Module

Create a Unet-based generator.
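The constructor arguments suggest the usual U-Net encoder/decoder layout. The sketch below shows how `num_downs` and `ngf` would shape the encoder. It assumes, as in the widely used pix2pix U-Net, that each of the `num_downs` levels halves the spatial size with a stride-2 convolution, and that the filter count starts at `ngf`, doubles per level, and is capped at `8 * ngf`. This module does not confirm those details. `unet_level_plan` is a hypothetical helper, not part of this module.

```python
from typing import List, Tuple


def unet_level_plan(size: int, num_downs: int, ngf: int = 64) -> List[Tuple[int, int]]:
    """Return (spatial_size, channels) after each encoder level.

    Hypothetical helper. Assumes every level halves the spatial size
    (stride-2 downsampling) and that the filter count doubles from ngf
    up to a cap of 8 * ngf.
    """
    factor = 2 ** num_downs
    if size % factor != 0:
        raise ValueError(f"size {size} is not divisible by 2**num_downs = {factor}")
    plan = []
    for level in range(num_downs):
        size //= 2  # one stride-2 downsampling per level
        channels = ngf * min(2 ** level, 8)  # ngf, 2*ngf, 4*ngf, then capped at 8*ngf
        plan.append((size, channels))
    return plan


if __name__ == "__main__":
    for spatial, channels in unet_level_plan(256, num_downs=8, ngf=64):
        print(spatial, channels)
```

Under these assumptions, a 256x256 input with `num_downs=8` reaches a 1x1 bottleneck with `8 * ngf` channels. The input side length therefore has to be divisible by `2 ** num_downs`.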

forward(input: torch.Tensor) → torch.Tensor

Standard forward pass: run input through the generator and return the generated tensor.

Parameters

input (Tensor) – the input tensor
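The sketch below shows what forward would plausibly accept and return. It assumes an NCHW batch whose channel count equals `input_nc` and whose height and width are divisible by `2 ** num_downs`. It also assumes the decoder mirrors the encoder, so the output has `output_nc` channels at the input's height and width. `expected_output_shape` is a hypothetical helper that works on plain shape tuples, so it runs without PyTorch. It is not part of this module's API.

```python
from typing import Tuple


def expected_output_shape(
    input_shape: Tuple[int, int, int, int],
    input_nc: int,
    output_nc: int,
    num_downs: int,
) -> Tuple[int, int, int, int]:
    """Predict the forward() output shape for an NCHW input (hypothetical helper)."""
    n, c, h, w = input_shape
    if c != input_nc:
        raise ValueError(f"expected {input_nc} input channels, got {c}")
    factor = 2 ** num_downs
    if h % factor or w % factor:
        raise ValueError(f"H and W must be divisible by {factor}, got {h}x{w}")
    # Assumption: the decoder upsamples back through every level, restoring H and W.
    return (n, output_nc, h, w)


if __name__ == "__main__":
    print(expected_output_shape((4, 3, 256, 256), input_nc=3, output_nc=3, num_downs=8))
```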

training: bool