Source code for simulation.utils.machine_learning.cycle_gan.models.no_patch_discriminator

import torch
from torch import nn as nn


class NoPatchDiscriminator(nn.Module):
    """Discriminator that emits one score per image instead of a patch map.

    The network is a plain stack of stride-2 convolutions followed by a final
    single-channel convolution; whatever spatial extent remains afterwards is
    average-pooled away in :meth:`forward`.
    """

    def __init__(
        self,
        input_nc: int,
        norm_layer: nn.Module = nn.BatchNorm2d,
        n_layers_d: int = 4,
        use_sigmoid: bool = True,
    ):
        """Construct a no patch gan discriminator.

        Args:
            input_nc (int): the number of channels in input images
            norm_layer (nn.Module): normalization layer; called with the
                number of channels to build each normalization module
            n_layers_d (int): the number of channel-doubling convolution
                blocks appended after the two fixed 64-channel blocks
            use_sigmoid (bool): sigmoid activation at the end
        """
        super().__init__()
        self.use_sigmoid = use_sigmoid

        # Two fixed stem blocks: input_nc -> 64 -> 64, each halving H and W.
        model = [
            nn.Conv2d(input_nc, 64, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        model += [
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            norm_layer(64),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Each additional block doubles the channel count and halves H and W.
        channels = 64
        for _ in range(n_layers_d):
            model += [
                nn.Conv2d(channels, channels * 2, 4, stride=2, padding=1),
                norm_layer(channels * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels *= 2

        # The head must consume however many channels the loop produced
        # (64 * 2 ** n_layers_d). A literal 512 here only matched
        # n_layers_d == 3 and made every other depth, including the default,
        # fail with a channel mismatch on the first forward pass.
        # With kernel 8 and padding 1 the incoming feature map must be at
        # least 6x6.
        model += [nn.Conv2d(channels, 1, 8, stride=2, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forwarding through network and avg pooling.

        Args:
            x (torch.Tensor): the input tensor; expected to be a 4-D batch
                laid out as (N, input_nc, H, W)

        Returns:
            torch.Tensor: tensor of shape (N, 1); passed through a sigmoid
            when ``use_sigmoid`` is set, otherwise the raw pooled scores.
        """
        x = self.model(x)
        # Average over the full remaining spatial extent, then flatten to
        # one row per batch element.
        output = torch.nn.functional.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
        return torch.sigmoid(output) if self.use_sigmoid else output