simulation.utils.machine_learning.cycle_gan.models package

Submodules

simulation.utils.machine_learning.cycle_gan.models.base_model module

Classes:

CycleGANNetworks(g_a_to_b, g_b_to_a, d_a, d_b)

Container class for all networks used within the CycleGAN.

BaseModel(netg_a_to_b, netg_b_to_a, netd_a, …)

class CycleGANNetworks(g_a_to_b: torch.nn.modules.module.Module, g_b_to_a: torch.nn.modules.module.Module, d_a: Optional[torch.nn.modules.module.Module] = None, d_b: Optional[torch.nn.modules.module.Module] = None)[source]

Bases: object

Container class for all networks used within the CycleGAN.

The CycleGAN generally requires images from two domains a and b. It aims to translate images from one domain to the other.
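A minimal construction sketch (assumption: the factory helpers create_generator and create_discriminator documented further below; all argument values are illustrative):

    from simulation.utils.machine_learning.cycle_gan.models.base_model import (
        CycleGANNetworks,
    )
    from simulation.utils.machine_learning.cycle_gan.models.discriminator import (
        create_discriminator,
    )
    from simulation.utils.machine_learning.cycle_gan.models.generator import (
        create_generator,
    )

    # Two generators (one per translation direction) and two
    # discriminators (one per domain), bundled into the container.
    g_a_to_b = create_generator(input_nc=3, output_nc=3, ngf=64, netg="resnet_9blocks")
    g_b_to_a = create_generator(input_nc=3, output_nc=3, ngf=64, netg="resnet_9blocks")
    d_a = create_discriminator(input_nc=3, ndf=64, netd="basic")
    d_b = create_discriminator(input_nc=3, ndf=64, netd="basic")

    networks = CycleGANNetworks(g_a_to_b, g_b_to_a, d_a, d_b)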

Attributes:

g_a_to_b

Generator that transforms images from domain a to domain b.

g_b_to_a

Generator that transforms images from domain b to domain a.

d_a

Discriminator that decides whether images in domain a are real or fake.

d_b

Discriminator that decides whether images in domain b are real or fake.

Methods:

save(prefix_path)

Save all the networks to the disk.

load(prefix_path, device)

Load all the networks from the disk.

print(verbose)

Print the total number of parameters in the network and (if verbose) network architecture.

g_a_to_b: torch.nn.modules.module.Module

Generator that transforms images from domain a to domain b.

g_b_to_a: torch.nn.modules.module.Module

Generator that transforms images from domain b to domain a.

d_a: torch.nn.modules.module.Module = None

Discriminator that decides whether images in domain a are real or fake.

d_b: torch.nn.modules.module.Module = None

Discriminator that decides whether images in domain b are real or fake.

save(prefix_path: str) → None[source]

Save all the networks to the disk.

Parameters

prefix_path (str) – the path which gets extended by the model name

load(prefix_path: str, device: torch.device)[source]

Load all the networks from the disk.

Parameters
  • prefix_path (str) – the path which is extended by the model name

  • device (torch.device) – The device on which the networks are loaded
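A hedged usage sketch for save and load; the exact file naming derived from prefix_path is an assumption:

    import torch

    # Save all four networks; each file name extends the given prefix
    # with the network's name (exact naming scheme not specified here).
    networks.save("checkpoints/epoch_10_")

    # Later, restore the weights onto the desired device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    networks.load("checkpoints/epoch_10_", device)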

__patch_instance_norm_state_dict(module: torch.nn.modules.module.Module, keys: List[str], i: int = 0) → None

Fix InstanceNorm checkpoint incompatibility with PyTorch versions prior to 0.4.

Parameters
  • state_dict (dict) – a dict containing parameters from the saved model files

  • module (nn.Module) – the network loaded from a file

  • keys (List[str]) – the keys inside the save file

  • i (int) – current index in network structure

print(verbose: bool) → None[source]

Print the total number of parameters in the network and (if verbose) network architecture.

Parameters

verbose (bool) – if True, also print the network architecture

class BaseModel(netg_a_to_b, netg_b_to_a, netd_a, netd_b, is_train, lambda_cycle, lambda_idt_a, lambda_idt_b, is_l1, optimizer_type, lr_policy, beta1: float = 0.5, lr: float = 0.0002, cycle_noise_stddev: float = 0)[source]

Bases: abc.ABC, simulation.utils.basics.init_options.InitOptions

Methods:

forward(real_a, real_b)

Run forward pass; called by both functions <optimize_parameters> and <test>.

test(batch_a, batch_b)

Forward function used at test time.

create_schedulers([lr_policy, …])

Create schedulers.

eval()

Switch the models to eval mode during test time.

update_learning_rate()

Update learning rates for all the networks.

do_iteration(batch_a, batch_b)

pre_training()

forward(real_a, real_b) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor][source]

Run forward pass; called by both functions <optimize_parameters> and <test>.

test(batch_a, batch_b) → simulation.utils.machine_learning.cycle_gan.models.cycle_gan_stats.CycleGANStats[source]

Forward function used at test time.

This function wraps the <forward> function in no_grad(), so intermediate results are not saved for backpropagation. It also calls <compute_visuals> to produce additional visualization results.
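The wrapping can be pictured as follows (a simplified sketch; the ordering of forward's return values and the stats fields used here are assumptions):

    import torch

    def test(self, batch_a, batch_b):
        # Disable gradient tracking: no intermediate activations are
        # kept for backpropagation while evaluating.
        with torch.no_grad():
            fake_b, rec_a, fake_a, rec_b = self.forward(batch_a, batch_b)
        # Package the results for visualization.
        return CycleGANStats(
            real_a=batch_a, real_b=batch_b,
            fake_a=fake_a, fake_b=fake_b,
            rec_a=rec_a, rec_b=rec_b,
        )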

create_schedulers(lr_policy: str = 'linear', lr_decay_iters: int = 50, lr_step_factor: float = 0.1, n_epochs: int = 100)[source]

Create schedulers.

Parameters
  • lr_policy – learning rate policy. [linear | step | plateau | cosine]

  • lr_decay_iters – multiply the learning rate by a gamma every lr_decay_iters iterations

  • lr_step_factor – factor by which the learning rate is multiplied at each decay step

  • n_epochs – number of epochs with the initial learning rate
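The named policies map naturally onto torch.optim.lr_scheduler. A minimal sketch, assuming the 'linear' policy follows the standard CycleGAN schedule (constant for n_epochs, then linear decay towards zero) and that the remaining mappings are as listed above:

    from torch.optim import lr_scheduler

    def make_scheduler(optimizer, lr_policy="linear", n_epochs=100,
                       lr_decay_iters=50, lr_step_factor=0.1):
        if lr_policy == "linear":
            # Keep the initial lr for n_epochs, then decay linearly to
            # zero over roughly the same number of epochs.
            def lambda_rule(epoch):
                return 1.0 - max(0, epoch - n_epochs) / float(n_epochs + 1)

            return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
        if lr_policy == "step":
            return lr_scheduler.StepLR(
                optimizer, step_size=lr_decay_iters, gamma=lr_step_factor
            )
        if lr_policy == "plateau":
            return lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        if lr_policy == "cosine":
            return lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
        raise NotImplementedError(f"lr policy {lr_policy} is not implemented")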

eval() → None[source]

Switch the models to eval mode during test time.

update_learning_rate() → None[source]

Update learning rates for all the networks.

abstract do_iteration(batch_a: Tuple[torch.Tensor, str], batch_b: Tuple[torch.Tensor, str])[source]
pre_training()[source]

simulation.utils.machine_learning.cycle_gan.models.cycle_gan_model module

Classes:

CycleGANModel(netg_a_to_b, netg_b_to_a, …)

This class implements the CycleGAN model, for learning image-to-image translation without paired data.

class CycleGANModel(netg_a_to_b: torch.nn.modules.module.Module, netg_b_to_a: torch.nn.modules.module.Module, netd_a: Optional[torch.nn.modules.module.Module] = None, netd_b: Optional[torch.nn.modules.module.Module] = None, is_train: bool = True, cycle_noise_stddev: int = 0, pool_size: int = 50, beta1: float = 0.5, lr: float = 0.0002, lr_policy: str = 'linear', lambda_idt_a: int = 10, lambda_idt_b: int = 10, lambda_cycle: float = 0.5, optimizer_type: str = 'adam', is_l1: bool = False)[source]

Bases: simulation.utils.machine_learning.cycle_gan.models.base_model.BaseModel

This class implements the CycleGAN model, for learning image-to-image translation without paired data.

CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf

Methods:

backward_d_basic(netd, real, fake)

Calculate GAN loss for the discriminator.

backward_d_a(real_a, fake_a)

Calculate GAN loss for discriminator d_a.

backward_d_b(real_b, fake_b)

Calculate GAN loss for discriminator d_b.

do_iteration(batch_a, batch_b)

Calculate losses, gradients, and update network weights; called in every training iteration.

backward_d_basic(netd: torch.nn.modules.module.Module, real: torch.Tensor, fake: torch.Tensor) → torch.Tensor[source]

Calculate GAN loss for the discriminator.

We also call loss_d.backward() to calculate the gradients.

Parameters
  • netd (nn.Module) – the discriminator network

  • real (torch.Tensor) – the real image

  • fake (torch.Tensor) – the fake image

Returns

Discriminator loss.
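A sketch of the computation, assuming the least-squares GAN objective mentioned elsewhere in this package (real images should score 1, fakes 0):

    import torch
    import torch.nn.functional as F

    def backward_d_basic(self, netd, real, fake):
        pred_real = netd(real)
        # Detach the fake image so no gradients flow into the generator.
        pred_fake = netd(fake.detach())
        loss_real = F.mse_loss(pred_real, torch.ones_like(pred_real))
        loss_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake))
        loss_d = 0.5 * (loss_real + loss_fake)
        loss_d.backward()  # gradients are computed here, as documented above
        return loss_d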

backward_d_a(real_a, fake_a) → float[source]

Calculate GAN loss for discriminator d_a.

backward_d_b(real_b, fake_b) → float[source]

Calculate GAN loss for discriminator d_b.

do_iteration(batch_a: torch.Tensor, batch_b: torch.Tensor)[source]

Calculate losses, gradients, and update network weights; called in every training iteration.
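A hedged sketch of one such iteration, combining the documented pieces; attribute names such as optimizer_g, optimizer_d, and networks, as well as the exact loss weighting, are assumptions:

    import torch
    import torch.nn.functional as F

    def do_iteration(self, real_a, real_b):
        # Translate both ways and reconstruct (cycle).
        fake_b = self.networks.g_a_to_b(real_a)
        rec_a = self.networks.g_b_to_a(fake_b)
        fake_a = self.networks.g_b_to_a(real_b)
        rec_b = self.networks.g_a_to_b(fake_a)

        # Generator update: adversarial + cycle-consistency + identity.
        self.optimizer_g.zero_grad()
        pred_a, pred_b = self.networks.d_a(fake_a), self.networks.d_b(fake_b)
        loss_gan = F.mse_loss(pred_a, torch.ones_like(pred_a)) + F.mse_loss(
            pred_b, torch.ones_like(pred_b)
        )
        loss_cycle = self.lambda_cycle * (
            F.l1_loss(rec_a, real_a) + F.l1_loss(rec_b, real_b)
        )
        loss_idt = self.lambda_idt_a * F.l1_loss(
            self.networks.g_b_to_a(real_a), real_a
        ) + self.lambda_idt_b * F.l1_loss(self.networks.g_a_to_b(real_b), real_b)
        (loss_gan + loss_cycle + loss_idt).backward()
        self.optimizer_g.step()

        # Discriminator updates (see backward_d_basic above).
        self.optimizer_d.zero_grad()
        self.backward_d_a(real_a, fake_a)
        self.backward_d_b(real_b, fake_b)
        self.optimizer_d.step()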


simulation.utils.machine_learning.cycle_gan.models.cycle_gan_stats module

Classes:

CycleGANStats(real_a, real_b, fake_a, …)

class CycleGANStats(real_a: torch.Tensor = None, real_b: torch.Tensor = None, fake_a: torch.Tensor = None, fake_b: torch.Tensor = None, rec_a: torch.Tensor = None, rec_b: torch.Tensor = None, idt_a: torch.Tensor = None, idt_b: torch.Tensor = None, loss_g_a_to_b: float = None, loss_g_b_to_a: float = None, loss_idt_a: float = None, loss_idt_b: float = None, loss_cycle_a: float = None, loss_cycle_b: float = None, loss_d_a: float = None, loss_d_b: float = None, w_distance_a: float = None, w_distance_b: float = None)[source]

Bases: object

Attributes:

real_a

real_b

fake_a

fake_b

rec_a

rec_b

idt_a

idt_b

loss_g_a_to_b

loss_g_b_to_a

loss_idt_a

loss_idt_b

loss_cycle_a

loss_cycle_b

loss_d_a

loss_d_b

w_distance_a

w_distance_b

Methods:

get_visuals()

get_losses()

real_a: torch.Tensor = None
real_b: torch.Tensor = None
fake_a: torch.Tensor = None
fake_b: torch.Tensor = None
rec_a: torch.Tensor = None
rec_b: torch.Tensor = None
idt_a: torch.Tensor = None
idt_b: torch.Tensor = None
loss_g_a_to_b: float = None
loss_g_b_to_a: float = None
loss_idt_a: float = None
loss_idt_b: float = None
loss_cycle_a: float = None
loss_cycle_b: float = None
loss_d_a: float = None
loss_d_b: float = None
w_distance_a: float = None
w_distance_b: float = None
get_visuals()[source]
get_losses()[source]

simulation.utils.machine_learning.cycle_gan.models.discriminator module

Functions:

create_discriminator(input_nc, ndf, netd[, …])

Create a discriminator.

create_discriminator(input_nc: int, ndf: int, netd: str, n_layers_d: int = 3, norm: str = 'batch', use_sigmoid: bool = False) → torch.nn.modules.module.Module[source]

Create a discriminator.

Returns a discriminator network.

Our current implementation provides three types of discriminators:

[basic]: ‘PatchGAN’ classifier described in the original pix2pix paper. It can classify whether 70×70 overlapping patches are real or fake. Such a patch-level discriminator architecture has fewer parameters than a full-image discriminator and can work on arbitrarily sized images in a fully convolutional fashion.

[n_layers]: with this mode, you can specify the number of conv layers in the discriminator with the parameter <n_layers_d> (default: 3, as used in [basic]).

[no_patch]: a discriminator without patch-based classification; it forwards the image through the network and average-pools to a single per-image decision (see NoPatchDiscriminator below).

They use LeakyReLU for the non-linearity.

Parameters
  • input_nc (int) – # of input image channels: 3 for RGB and 1 for grayscale

  • ndf (int) – # of discriminator filters in the first conv layer

  • netd (str) – specify the discriminator architecture [basic | n_layers | no_patch]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the number of conv layers in the discriminator

  • n_layers_d (int) – number of layers in the discriminator network

  • norm (str) – instance normalization or batch normalization [instance | batch | none]

  • use_sigmoid (bool) – Use sigmoid activation at the end of discriminator network
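Hypothetical usage; all argument values are illustrative:

    import torch

    # A 70x70 PatchGAN for RGB input.
    netd = create_discriminator(input_nc=3, ndf=64, netd="basic", norm="instance")

    # The output is a spatial map of per-patch real/fake scores rather
    # than a single scalar, which is what makes the discriminator fully
    # convolutional.
    scores = netd(torch.randn(1, 3, 256, 256))
    print(scores.shape)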

simulation.utils.machine_learning.cycle_gan.models.generator module

Functions:

create_generator(input_nc, output_nc, ngf, netg)

Create a generator.

create_generator(input_nc: int, output_nc: int, ngf: int, netg: str, norm: str = 'batch', use_dropout: bool = False, activation: torch.nn.modules.module.Module = Tanh(), conv_layers_in_block: int = 2, dilations: Optional[List[int]] = None) → torch.nn.modules.module.Module[source]

Create a generator.

Returns a generator network.

Our current implementation provides two types of generators.

U-Net:

[unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images). The original U-Net paper: https://arxiv.org/abs/1505.04597

Resnet-based generator:

[resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks). The Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. We adapt the Torch code from Justin Johnson’s neural style transfer project (https://github.com/jcjohnson/fast-neural-style).

It uses ReLU for the non-linearity.

Parameters
  • input_nc (int) – # of input image channels: 3 for RGB and 1 for grayscale

  • output_nc (int) – # of output image channels: 3 for RGB and 1 for grayscale

  • ngf (int) – # of gen filters in the last conv layer

  • netg (str) – specify generator architecture [resnet_<ANY_INTEGER>blocks | unet_256 | unet_128]

  • norm (str) – instance normalization or batch normalization [instance | batch | none]

  • use_dropout (bool) – enable or disable dropout

  • activation (nn.Module) – the activation function to use

  • conv_layers_in_block (int) – specify number of convolution layers per resnet block

  • dilations – dilation for individual conv layers in every resnet block
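Hypothetical usage; all argument values are illustrative:

    import torch

    # A ResNet generator with 9 blocks for RGB images.
    netg = create_generator(
        input_nc=3, output_nc=3, ngf=64, netg="resnet_9blocks", norm="instance"
    )
    fake = netg(torch.randn(1, 3, 256, 256))  # same spatial size, output_nc channels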

simulation.utils.machine_learning.cycle_gan.models.n_layer_discriminator module

Classes:

NLayerDiscriminator(input_nc, ndf, n_layers, …)

Defines a PatchGAN discriminator.

class NLayerDiscriminator(input_nc: int, ndf: int = 64, n_layers: int = 3, norm_layer: torch.nn.modules.module.Module = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, use_sigmoid: bool = True, is_quadratic: bool = True)[source]

Bases: torch.nn.modules.module.Module

Defines a PatchGAN discriminator.

Methods:

forward(input)

Standard forward.

forward(input: torch.Tensor) → torch.Tensor[source]

Standard forward.

Parameters

input (Tensor) – the input tensor
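Direct instantiation, for illustration (constructor defaults as documented above):

    import torch
    from torch import nn

    disc = NLayerDiscriminator(
        input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d
    )
    patch_scores = disc(torch.randn(2, 3, 256, 256))  # per-patch score map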


simulation.utils.machine_learning.cycle_gan.models.no_patch_discriminator module

Classes:

NoPatchDiscriminator(input_nc, norm_layer, …)

class NoPatchDiscriminator(input_nc: int, norm_layer: torch.nn.modules.module.Module = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, n_layers_d: int = 4, use_sigmoid: bool = True)[source]

Bases: torch.nn.modules.module.Module

Methods:

forward(x)

Forward through the network followed by average pooling.

forward(x: torch.Tensor) → torch.Tensor[source]

Forward through the network followed by average pooling.

Parameters

x (torch.Tensor) – the input tensor
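A sketch of the documented behavior; the conv-stack attribute name (self.model) is an assumption:

    import torch
    import torch.nn.functional as F

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.model(x)  # run the convolutional stack
        # Average-pool away the spatial dimensions so every image is
        # reduced to a single decision instead of a patch map.
        return F.adaptive_avg_pool2d(features, 1).flatten(1)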


simulation.utils.machine_learning.cycle_gan.models.wcycle_gan module

Classes:

WassersteinCycleGANModel(netg_a_to_b, …[, …])

This class implements the Wasserstein CycleGAN model, for learning image-to-image translation without paired data.

class WassersteinCycleGANModel(netg_a_to_b: torch.nn.modules.module.Module, netg_b_to_a: torch.nn.modules.module.Module, netd_a: Optional[torch.nn.modules.module.Module] = None, netd_b: Optional[torch.nn.modules.module.Module] = None, is_train: bool = True, beta1: float = 0.5, lr: float = 0.0002, lr_policy: str = 'linear', lambda_idt_a: int = 10, lambda_idt_b: int = 10, lambda_cycle: float = 0.5, optimizer_type: str = 'rms_prop', is_l1: bool = False, wgan_n_critic: int = 5, wgan_initial_n_critic: int = 5, wgan_clip_lower=-0.01, wgan_clip_upper=0.01)[source]

Bases: simulation.utils.machine_learning.cycle_gan.models.base_model.BaseModel

This class implements the Wasserstein CycleGAN model, for learning image-to-image translation without paired data.

By default, it uses a ‘--netg resnet_9blocks’ ResNet generator, a ‘--netd basic’ discriminator (the PatchGAN introduced by pix2pix), and a least-squares GAN objective (‘--gan_mode lsgan’).

CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf

Methods:

update_critic_a(batch_a, batch_b[, clip_bounds])

update_critic_b(batch_a, batch_b[, clip_bounds])

update_generators(batch_a, batch_b)

pre_training(critic_batches)

do_iteration(batch_a, batch_b, critic_batches)

update_critic_a(batch_a: torch.Tensor, batch_b: torch.Tensor, clip_bounds: Optional[Tuple[float, float]] = None)[source]
update_critic_b(batch_a: torch.Tensor, batch_b: torch.Tensor, clip_bounds: Optional[Tuple[float, float]] = None)[source]
update_generators(batch_a: torch.Tensor, batch_b: torch.Tensor)[source]
pre_training(critic_batches)[source]
do_iteration(batch_a: torch.Tensor, batch_b: torch.Tensor, critic_batches: List[Tuple[torch.Tensor, torch.Tensor]])[source]
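A sketch of a single critic update with weight clipping, as suggested by the wgan_* parameters above; the function shape and names are assumptions:

    import torch

    def update_critic(critic, optimizer, real, fake, clip_bounds=(-0.01, 0.01)):
        optimizer.zero_grad()
        # Wasserstein critic loss: maximize E[D(real)] - E[D(fake)],
        # i.e. minimize the negated difference.
        loss = critic(fake.detach()).mean() - critic(real).mean()
        loss.backward()
        optimizer.step()
        # Enforce the Lipschitz constraint by clamping every weight
        # into [wgan_clip_lower, wgan_clip_upper].
        lower, upper = clip_bounds
        for p in critic.parameters():
            p.data.clamp_(lower, upper)
        return -loss  # estimate of the Wasserstein distance

do_iteration presumably runs such a critic step wgan_n_critic times per generator update, with pre_training performing wgan_initial_n_critic warm-up steps on the given critic_batches.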

Module contents