Source code for simulation.utils.machine_learning.cycle_gan.models.n_layer_discriminator

import functools

from torch import Tensor
from torch import nn as nn


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator.

    The network is a stack of 4x4 convolutions that ends in a 1-channel
    prediction map, optionally followed by a sigmoid.
    """

    def __init__(
        self,
        input_nc: int,
        ndf: int = 64,
        n_layers: int = 3,
        norm_layer: nn.Module = nn.BatchNorm2d,
        use_sigmoid: bool = True,
        is_quadratic: bool = True,
    ):
        """Construct a PatchGAN discriminator.

        Args:
            input_nc (int): the number of channels in input images
            ndf (int): the number of filters produced by the first conv layer;
                later layers use ``ndf * min(2 ** n, 8)`` filters
            n_layers (int): the number of strided conv layers (the first layer
                plus ``n_layers - 1`` normalized layers); two stride-1
                convolutions are always appended after them
            norm_layer (nn.Module): normalization layer factory, called with
                the number of channels; may be wrapped in ``functools.partial``
            use_sigmoid (bool): sigmoid activation at the end
            is_quadratic (bool): if True the first convolution uses stride 2
                and symmetric padding; otherwise it uses stride ``(1, 2)`` and
                padding ``(2 * padding_width, padding_width)``
        """
        super().__init__()

        # Unwrap functools.partial (including subclasses) to inspect the
        # underlying normalization class.
        if isinstance(norm_layer, functools.partial):
            norm_class = norm_layer.func
        else:
            norm_class = norm_layer

        # NOTE(review): the original comment here read "no need to use bias as
        # BatchNorm2d has affine parameters", yet the condition enables the
        # conv bias exactly when the norm layer IS BatchNorm2d. The behavior is
        # kept unchanged so the set of parameters (and thus saved checkpoints)
        # stays the same -- confirm whether `!=` was intended.
        use_bias = norm_class == nn.BatchNorm2d

        kw = 4
        padding_width = 1

        if is_quadratic:
            padding_first_layer = padding_width
            stride_first_layer = 2
        else:
            padding_first_layer = (2 * padding_width, padding_width)
            stride_first_layer = (1, 2)

        # First layer: no normalization, default conv bias.
        sequence = [
            nn.Conv2d(
                input_nc,
                ndf,
                kernel_size=kw,
                stride=stride_first_layer,
                padding=padding_first_layer,
            ),
            nn.LeakyReLU(0.2, True),
        ]

        # Strided layers: gradually increase the number of filters (capped at 8x).
        num_filters_multiplier = 1
        for n in range(1, n_layers):
            num_filters_multiplier_prev = num_filters_multiplier
            num_filters_multiplier = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(
                    ndf * num_filters_multiplier_prev,
                    ndf * num_filters_multiplier,
                    kernel_size=kw,
                    stride=2,
                    padding=padding_width,
                    bias=use_bias,
                ),
                norm_layer(ndf * num_filters_multiplier),
                nn.LeakyReLU(0.2, True),
            ]

        # One more normalized layer, this time with stride 1.
        num_filters_multiplier_prev = num_filters_multiplier
        num_filters_multiplier = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(
                ndf * num_filters_multiplier_prev,
                ndf * num_filters_multiplier,
                kernel_size=kw,
                stride=1,
                padding=padding_width,
                bias=use_bias,
            ),
            norm_layer(ndf * num_filters_multiplier),
            nn.LeakyReLU(0.2, True),
        ]

        # Output 1 channel prediction map.
        sequence += [
            nn.Conv2d(
                ndf * num_filters_multiplier,
                1,
                kernel_size=kw,
                stride=1,
                padding=padding_width,
            )
        ]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input: Tensor) -> Tensor:
        """Standard forward.

        Args:
            input (Tensor): the input tensor

        Returns:
            Tensor: the output of the sequential model applied to ``input``
        """
        return self.model(input)