Source code for simulation.utils.machine_learning.models.wasserstein_critic
import functools
from typing import List, Optional, Tuple
import torch
from torch import nn as nn
from torch.nn import Flatten
from simulation.utils.basics.init_options import InitOptions
from .helper import get_norm_layer
from .resnet_block import ResnetBlock
class WassersteinCritic(nn.Module, InitOptions):
def __init__(
self,
input_nc: int,
n_blocks: int = 3,
norm: str = "instance",
        ndf: int = 32,
        height: int = 256,
        width: int = 256,
use_dropout: bool = False,
padding_type: str = "reflect",
conv_layers_in_block: int = 2,
dilations: Optional[List[int]] = None,
):
"""WGAN Critic.
Implementation follows https://github.com/martinarjovsky/WassersteinGAN
Args:
input_nc: Number of channels in input images
norm: Normalization layer
n_blocks: Number of resnet blocks
ndf: Number of features in conv layers
height: Height of the input image
width: Width of the input image
use_dropout: Indicate usage of dropout in resnet blocks
padding_type: Type of padding to be used
conv_layers_in_block: Number of convolution layers in each resnet block
dilations: Type of dilations within each resnet block
"""
super().__init__()
norm_layer = get_norm_layer(norm)
        if isinstance(norm_layer, functools.partial):
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
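        # Convolutions keep a bias term only when instance normalization is used; a norm
        # layer with a learnable shift (e.g. batch norm) would make the bias redundant.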
model = [
nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ndf, kernel_size=7, padding=0, bias=use_bias),
norm_layer(ndf),
nn.ReLU(True),
]
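        # Reflection padding of 3 followed by a 7x7 convolution with padding=0 keeps the
        # spatial size unchanged (H + 2*3 - 7 + 1 = H) while mapping input_nc -> ndf channels.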
dilations = (
[1 for _ in range(conv_layers_in_block)] if dilations is None else dilations
)
n_downsampling = 2
for i in range(n_downsampling): # add downsampling layers
multiplier = 2 ** i
model += [
nn.Conv2d(
ndf * multiplier,
ndf * multiplier * 2,
kernel_size=3,
stride=2,
padding=1,
bias=use_bias,
),
norm_layer(ndf * multiplier * 2),
nn.ReLU(True),
]
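        # Each downsampling layer halves the spatial resolution and doubles the channel
        # count, so the feature map is now (ndf * 4) x (height / 4) x (width / 4).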
multiplier = 2 ** n_downsampling
for i in range(n_blocks): # add ResNet blocks
model += [
ResnetBlock(
ndf * multiplier,
padding_type=padding_type,
norm_layer=norm_layer,
use_dropout=use_dropout,
use_bias=use_bias,
n_conv_layers=conv_layers_in_block,
dilations=dilations,
)
]
model.append(Flatten())
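        # The flattened feature map has ndf * multiplier channels at a resolution of
        # (height / 2**n_downsampling) x (width / 2**n_downsampling); with the defaults
        # this yields 128 * 64 * 64 = 524288 input features for the final linear layer,
        # which maps them to a single unbounded critic score.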
model += [
nn.Linear(
int(ndf * multiplier * height * width / pow(2, 2 * n_downsampling)), 1
),
]
self.model = nn.Sequential(*model)
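
    # The listing defines no explicit forward pass; a minimal one applying self.model is
    # assumed here, since perform_optimization_step below evaluates the critic via self(batch).
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the critic score for a batch of input images."""
        return self.model(x)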
    def _clip_weights(self, bounds: Tuple[float, float] = (-0.01, 0.01)):
"""Clip weights to given bounds."""
        # Clip the critic's weights (approximately enforces the Lipschitz constraint)
for p in self.parameters():
p.data.clamp_(*bounds)
    def perform_optimization_step(
self,
generator: nn.Module,
optimizer: torch.optim.Optimizer,
batch_critic: torch.Tensor,
batch_generator: torch.Tensor,
        weight_clips: Optional[Tuple[float, float]] = None,
) -> float:
"""Do one iteration to update the parameters.
Args:
generator: Generation network
optimizer: Optimizer for the critic's weights
batch_critic: A batch of inputs for the critic
batch_generator: A batch of inputs for the generator
weight_clips: Optional weight bounds for the critic's weights
Return:
Current wasserstein distance estimated by critic.
"""
"""Attempt to use WGAN divergence instead of weight clipping.
from torch.autograd import Variable
import torch.autograd as autograd
p = 1
batch = Variable(batch_critic.type(torch.Tensor), requires_grad=True)
grad_out = Variable(
torch.Tensor(batch.size(0), 1).fill_(1.0), requires_grad=False
).to(self.device)
grad = autograd.grad(
self(batch),
batch,
grad_out,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
grad_norm = grad.view(grad.size(0), -1).pow(2).sum(1) ** (p / 2)
# Loss: (gradient ascent)
loss_d = (-torch.mean(f_x) + torch.mean(f_g_x)) + 1 / 2 * torch.mean(grad_norm).to(
self.device
)
"""
optimizer.zero_grad()
        # Critic scores for the real batch
        f_x = self(batch_critic)
        # Generate fake samples; detach so that gradients do not flow into the generator
        g_x = generator(batch_generator).detach()
        # Critic scores for the generated batch
        f_g_x = self(g_x)
        # Negative Wasserstein estimate; minimizing it is gradient ascent on the critic objective
        loss_d = -1 * (torch.mean(f_x) - torch.mean(f_g_x))
loss_d.backward()
optimizer.step()
if weight_clips is not None:
self._clip_weights(weight_clips)
return -1 * loss_d.detach().item()
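

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal WGAN
# training schedule with several critic steps per generator step, following the
# referenced WassersteinGAN implementation. The generator network and the two
# batch iterables are hypothetical placeholders; input_nc=3 assumes RGB images.
# ---------------------------------------------------------------------------
def _example_training_loop(
    generator: nn.Module,
    critic_batches,
    generator_batches,
    n_critic: int = 5,
) -> None:
    """Sketch: alternate n_critic critic updates with one generator update."""
    critic = WassersteinCritic(input_nc=3, n_blocks=3, ndf=32, height=256, width=256)
    critic_optimizer = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
    generator_optimizer = torch.optim.RMSprop(generator.parameters(), lr=5e-5)

    for batch_critic, batch_generator in zip(critic_batches, generator_batches):
        # Several critic steps per generator step; the original WGAN recipe would
        # draw a fresh batch for every critic step.
        for _ in range(n_critic):
            critic.perform_optimization_step(
                generator,
                critic_optimizer,
                batch_critic,
                batch_generator,
                weight_clips=(-0.01, 0.01),
            )

        # Generator step: maximize the critic score of generated samples,
        # i.e. minimize -E[critic(G(x))].
        generator_optimizer.zero_grad()
        generator_loss = -torch.mean(critic(generator(batch_generator)))
        generator_loss.backward()
        generator_optimizer.step()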