# Source code for simulation.utils.machine_learning.models.test.test_wgan_critic

```
"""Perform some basic tests for the WGAN critic."""

import itertools
import pickle

import torch
from hypothesis import given
from hypothesis.strategies import floats, integers

from ..wasserstein_critic import WassersteinCritic

@given(floats(min_value=-0.1, max_value=0), floats(min_value=0, max_value=0.1))
def test_weight_clipping(lower, upper):
    if lower == upper:
        return
    bounds = (lower, upper)
    print(f"Test weight clipping with bounds: {bounds}")
    # Since the weights are initialized randomly, some of them are almost
    # certainly out of bounds...
    critic = WassersteinCritic(input_nc=1, height=32, width=32)

    critic._clip_weights(bounds)
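    # Weight clipping is how the original WGAN enforces the critic's Lipschitz
    # constraint (Arjovsky et al., 2017); every parameter should now lie inside
    # `bounds`.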

    # Loop through the parameters and ensure that clipping works
    for t in critic.parameters():
        lower = bounds[0] * torch.ones_like(t)
        upper = bounds[1] * torch.ones_like(t)
        assert torch.all(t >= lower).item()
        assert torch.all(t <= upper).item()

@given(
    integers(min_value=1, max_value=16),
    integers(min_value=1, max_value=3),
    integers(min_value=3, max_value=6).map(lambda x: 2 ** x),
    integers(min_value=3, max_value=6).map(lambda x: 2 ** x),
)
def test_forward(batch_size, input_nc, height, width):
    print(
        f"Test critic input with batch_size: {batch_size}, "
        f"input_nc: {input_nc}, height: {height}, width: {width}"
    )
    critic = WassersteinCritic(input_nc=input_nc, height=height, width=width)
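    # The critic maps each sample of an (N, C, H, W) batch to a single scalar
    # score, so the output must have shape (N, 1).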

    batch = torch.rand(batch_size, input_nc, height, width)
    output = critic(batch)
    assert output.shape == torch.Size([batch_size, 1])

def test_optimization_step():
    """Test very basic functionality of the optimization.

    Testing whether the optimization truly works is hard. Here, only some very
    basic properties are checked:

    * Does the Wasserstein distance increase when running the optimization?
    * Is the distance close to zero if random distributions are given
      and the generator is the identity?
    * Does the generator remain unchanged?
    """
    # 1. Distances close to zero if both batches are randomly sampled each
    #    iteration and the generator is the identity
    batch_size, input_nc, height, width = 16, 2, 32, 32

    critic = WassersteinCritic(input_nc=input_nc, height=height, width=width)
    optimizer = torch.optim.RMSprop(critic.parameters(), lr=0.00005)
    generator = torch.nn.Identity()

    iterations = 100

    distances = [
        critic.perform_optimization_step(
            generator,
            optimizer,
            torch.rand(batch_size, input_nc, height, width),
            torch.rand(batch_size, input_nc, height, width),
        )
        for _ in range(iterations)
    ]
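    # With an identity generator and freshly sampled batches on every step, both
    # inputs come from the same distribution, so the mean estimated Wasserstein
    # distance should stay close to zero.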

    assert sum(distances) / iterations < 0.1

    # 2. Distances increase if the batches are constant; the critic should learn
    #    to tell them apart while the generator is the identity
    critic = WassersteinCritic(input_nc=input_nc, height=height, width=width)
    optimizer = torch.optim.RMSprop(critic.parameters(), lr=0.00005)
    generator = torch.nn.Identity()

    batch_critic, batch_generator = (
        torch.rand(batch_size, input_nc, height, width),
        torch.rand(batch_size, input_nc, height, width),
    )

    distances = [
        critic.perform_optimization_step(
            generator, optimizer, batch_critic, batch_generator
        )
        for _ in range(iterations)
    ]
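    # As the critic learns to tell the two fixed batches apart, its distance
    # estimates should grow, so the second half of the run should dominate the
    # first half.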

    assert sum(distances[iterations // 2 :]) > sum(distances[: iterations // 2])

    # 3. Test that the generator's parameters are not modified by the critic's
    #    optimization. That should never happen!
    critic = WassersteinCritic(input_nc=input_nc, height=height, width=width)

    generator = torch.nn.Conv2d(input_nc, input_nc, kernel_size=3)
    # Keep an untouched deep copy of the generator (a pickle round-trip copies
    # the module together with its parameters) to compare against afterwards.
    generator_clone = pickle.loads(pickle.dumps(generator))
    in_params = list(generator_clone.parameters())

    # Add the generator's parameters to the optimizer as well to create a scenario
    # in which the critic's optimization could change the generator's parameters.
    # It must not do that.
    optimizer = torch.optim.RMSprop(
        itertools.chain(critic.parameters(), generator.parameters()), lr=0.00005
    )
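    # RMSprop with a learning rate of 0.00005 matches the optimizer setup
    # recommended in the original WGAN paper.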

    batch_critic, batch_generator = (
        torch.rand(batch_size, input_nc, height, width),
        torch.rand(batch_size, input_nc, height, width),
    )

    for _ in range(10):
        critic.perform_optimization_step(
            generator, optimizer, batch_critic, batch_generator
        )

    out_params = list(generator.parameters())

    assert len(in_params) == len(out_params)  # Shouldn't fail anyway...

    # Every parameter must be element-wise identical to its pre-optimization
    # copy: the critic's step must not touch the generator.
    for i, o in zip(in_params, out_params):
        assert torch.all(i == o)

def main():
    test_weight_clipping()
    test_forward()
    test_optimization_step()

if __name__ == "__main__":
    main()
```
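
The module can be exercised directly through its `main()` entry point, or via pytest. The invocations below are a sketch, assuming the package root is on `PYTHONPATH` and the file path follows the module name above:

```
python -m simulation.utils.machine_learning.models.test.test_wgan_critic
python -m pytest simulation/utils/machine_learning/models/test/test_wgan_critic.py
```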