simulation.utils.machine_learning.cycle_gan.models package¶
Subpackages¶
Submodules¶
simulation.utils.machine_learning.cycle_gan.models.base_model module¶
Classes:
CycleGANNetworks: Container class for all networks used within the CycleGAN.
BaseModel
- class CycleGANNetworks(g_a_to_b: torch.nn.modules.module.Module, g_b_to_a: torch.nn.modules.module.Module, d_a: Optional[torch.nn.modules.module.Module] = None, d_b: Optional[torch.nn.modules.module.Module] = None)[source]¶
Bases:
object
Container class for all networks used within the CycleGAN.
The CycleGAN generally requires images from two domains a and b. It aims to translate images from one domain to the other.
Attributes:
g_a_to_b: Generator that transforms images from domain a to domain b.
g_b_to_a: Generator that transforms images from domain b to domain a.
d_a: Discriminator that decides for images if they are real or fake in domain a.
d_b: Discriminator that decides for images if they are real or fake in domain b.
Methods:
save(prefix_path): Save all the networks to the disk.
load(prefix_path, device): Load all the networks from the disk.
print(verbose): Print the total number of parameters in the network and (if verbose) the network architecture.
- g_a_to_b: torch.nn.modules.module.Module¶
Generator that transforms images from domain a to domain b.
- g_b_to_a: torch.nn.modules.module.Module¶
Generator that transforms images from domain b to domain a.
- d_a: torch.nn.modules.module.Module = None¶
Discriminator that decides for images if they are real or fake in domain a.
- d_b: torch.nn.modules.module.Module = None¶
Discriminator that decides for images if they are real or fake in domain b.
- save(prefix_path: str) → None[source]¶
Save all the networks to the disk.
- Parameters
prefix_path (str) – the path which gets extended by the model name
- load(prefix_path: str, device: torch.device)[source]¶
Load all the networks from the disk.
- Parameters
prefix_path (str) – the path which is extended by the model name
device (torch.device) – The device on which the networks are loaded
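A minimal sketch of how saving and loading with a path prefix might work. The helper names `save_networks` and `load_networks`, the `.pth` suffix, and the `latest_net_` prefix are assumptions made for illustration; the actual file naming used by `CycleGANNetworks` is not documented here.

```python
import os
import tempfile

import torch
from torch import nn


def save_networks(networks: dict, prefix_path: str) -> None:
    """Write each network's state dict to '<prefix_path><name>.pth' (assumed naming)."""
    for name, net in networks.items():
        if net is not None:  # the discriminators are optional
            torch.save(net.state_dict(), prefix_path + name + ".pth")


def load_networks(networks: dict, prefix_path: str, device: torch.device) -> None:
    """Read each state dict back and move the network to `device`."""
    for name, net in networks.items():
        if net is not None:
            state_dict = torch.load(prefix_path + name + ".pth", map_location=device)
            net.load_state_dict(state_dict)
            net.to(device)


# Tiny stand-in networks; real generators would come from create_generator.
source = {"g_a_to_b": nn.Linear(2, 2), "g_b_to_a": nn.Linear(2, 2), "d_a": None, "d_b": None}
target = {"g_a_to_b": nn.Linear(2, 2), "g_b_to_a": nn.Linear(2, 2), "d_a": None, "d_b": None}

with tempfile.TemporaryDirectory() as tmp_dir:
    prefix = os.path.join(tmp_dir, "latest_net_")  # hypothetical prefix
    save_networks(source, prefix)
    saved_files = sorted(os.listdir(tmp_dir))
    load_networks(target, prefix, torch.device("cpu"))
```

Passing `map_location` lets a checkpoint written on a GPU be restored on a CPU-only machine.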
- __patch_instance_norm_state_dict(module: torch.nn.modules.module.Module, keys: List[str], i: int = 0) → None¶
Fix InstanceNorm checkpoint incompatibility (for checkpoints saved with PyTorch versions prior to 0.4).
- Parameters
state_dict (dict) – a dict containing parameters from the saved model files
module (nn.Module) – the network loaded from a file
keys (List[str]) – the keys inside the save file
i (int) – current index in network structure
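The following standalone sketch shows what such a patch typically does: it walks a dotted state-dict key down the module tree and drops InstanceNorm running statistics that the current module no longer expects. It is written as a free function under that assumption and is not necessarily identical to the method documented above.

```python
from typing import Dict, List

import torch
from torch import nn


def patch_instance_norm_state_dict(
    state_dict: Dict[str, torch.Tensor], module: nn.Module, keys: List[str], i: int = 0
) -> None:
    """Drop InstanceNorm statistics entries that the current module does not expect."""
    key = keys[i]
    if i + 1 == len(keys):  # reached the parameter or buffer name
        is_instance_norm = module.__class__.__name__.startswith("InstanceNorm")
        if is_instance_norm and key in ("running_mean", "running_var"):
            if getattr(module, key) is None:
                state_dict.pop(".".join(keys))
        if is_instance_norm and key == "num_batches_tracked":
            state_dict.pop(".".join(keys))
    else:
        # Descend one level, e.g. from the Sequential container to layer "1".
        patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)


net = nn.Sequential(nn.Conv2d(3, 4, 3), nn.InstanceNorm2d(4))

# Simulate an old checkpoint that still carries running statistics.
state_dict = dict(net.state_dict())
state_dict["1.running_mean"] = torch.zeros(4)
state_dict["1.running_var"] = torch.ones(4)

for full_key in list(state_dict.keys()):
    patch_instance_norm_state_dict(state_dict, net, full_key.split("."))

net.load_state_dict(state_dict)  # loads cleanly once the stale keys are removed
```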
- class BaseModel(netg_a_to_b, netg_b_to_a, netd_a, netd_b, is_train, lambda_cycle, lambda_idt_a, lambda_idt_b, is_l1, optimizer_type, lr_policy, beta1: float = 0.5, lr: float = 0.0002, cycle_noise_stddev: float = 0)[source]¶
Bases:
abc.ABC
,simulation.utils.basics.init_options.InitOptions
Methods:
forward(real_a, real_b): Run forward pass; called by both functions <optimize_parameters> and <test>.
test(batch_a, batch_b): Forward function used at test time.
create_schedulers([lr_policy, …]): Create schedulers.
eval(): Put the models into eval mode at test time.
Update learning rates for all the networks.
do_iteration(batch_a, batch_b)
Attributes:
- forward(real_a, real_b) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor][source]¶
Run forward pass; called by both functions <optimize_parameters> and <test>.
- _abc_impl = <_abc_data object>¶
- test(batch_a, batch_b) → simulation.utils.machine_learning.cycle_gan.models.cycle_gan_stats.CycleGANStats[source]¶
Forward function used at test time.
This function wraps the <forward> function in no_grad() so that intermediate steps are not saved for backpropagation. It also calls <compute_visuals> to produce additional visualization results.
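As a rough illustration of the no_grad() wrapping described here, the sketch below runs a stand-in generator in inference mode. The single convolution is only a placeholder for a real generator and is not part of this package.

```python
import torch
from torch import nn

generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # placeholder for g_a_to_b
batch_a = torch.rand(1, 3, 8, 8)

generator.eval()       # switch layers such as dropout to inference behaviour
with torch.no_grad():  # do not record operations for backpropagation
    fake_b = generator(batch_a)
```

Because no graph is recorded, `fake_b.requires_grad` is `False` and memory use stays lower than during training.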
- create_schedulers(lr_policy: str = 'linear', lr_decay_iters: int = 50, lr_step_factor: float = 0.1, n_epochs: int = 100)[source]¶
Create schedulers.
- Parameters
lr_policy – learning rate policy. [linear | step | plateau | cosine]
lr_decay_iters – multiply by a gamma every lr_decay_iters iterations
lr_step_factor – multiply lr with this factor every epoch
n_epochs – number of epochs with the initial learning rate
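A sketch of what a linear policy commonly looks like: the learning rate stays at its initial value for `n_epochs` and then decays linearly towards zero. The `n_epochs_decay` argument is an assumption borrowed from common CycleGAN implementations; it is not part of the signature documented above.

```python
def linear_lr_factor(epoch: int, n_epochs: int = 100, n_epochs_decay: int = 100) -> float:
    """Multiplicative factor applied to the initial learning rate at `epoch`.

    n_epochs_decay is a hypothetical parameter: the number of epochs over
    which the rate decays linearly to zero.
    """
    return 1.0 - max(0, epoch - n_epochs) / float(n_epochs_decay + 1)


initial_lr = 0.0002
lr_at_start = initial_lr * linear_lr_factor(0)     # still the initial rate
lr_mid_decay = initial_lr * linear_lr_factor(150)  # roughly half of it
```

A function of this shape can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`.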
simulation.utils.machine_learning.cycle_gan.models.cycle_gan_model module¶
Classes:
CycleGANModel: This class implements the CycleGAN model, for learning image-to-image translation without paired data.
- class CycleGANModel(netg_a_to_b: torch.nn.modules.module.Module, netg_b_to_a: torch.nn.modules.module.Module, netd_a: Optional[torch.nn.modules.module.Module] = None, netd_b: Optional[torch.nn.modules.module.Module] = None, is_train: bool = True, cycle_noise_stddev: int = 0, pool_size: int = 50, beta1: float = 0.5, lr: float = 0.0002, lr_policy: str = 'linear', lambda_idt_a: int = 10, lambda_idt_b: int = 10, lambda_cycle: float = 0.5, optimizer_type: str = 'adam', is_l1: bool = False)[source]¶
Bases:
simulation.utils.machine_learning.cycle_gan.models.base_model.BaseModel
This class implements the CycleGAN model, for learning image-to-image translation without paired data.
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
Methods:
backward_d_basic(netd, real, fake): Calculate GAN loss for the discriminator.
backward_d_a(real_a, fake_a): Calculate GAN loss for discriminator D_A.
backward_d_b(real_b, fake_b): Calculate GAN loss for discriminator D_B.
do_iteration(batch_a, batch_b): Calculate losses, gradients, and update network weights; called in every training iteration.
Attributes:
- backward_d_basic(netd: torch.nn.modules.module.Module, real: torch.Tensor, fake: torch.Tensor) → torch.Tensor[source]¶
Calculate GAN loss for the discriminator.
We also call loss_d.backward() to calculate the gradients.
- Returns
Discriminator loss.
- Parameters
netd (nn.Module) – the discriminator network
real (torch.Tensor) – the real image
fake (torch.Tensor) – the fake image
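A minimal sketch of a discriminator loss of this kind. It assumes a least-squares objective (`nn.MSELoss`) and equal weighting of the real and fake terms; the loss actually used by `CycleGANModel` may differ.

```python
import torch
from torch import nn


def backward_d_basic(netd: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    criterion = nn.MSELoss()  # assumption: least-squares GAN objective
    # Real images should be scored as 1.
    pred_real = netd(real)
    loss_real = criterion(pred_real, torch.ones_like(pred_real))
    # Fake images should be scored as 0; detach() keeps the gradients
    # of this step from flowing back into the generator.
    pred_fake = netd(fake.detach())
    loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
    loss_d = (loss_real + loss_fake) * 0.5
    loss_d.backward()
    return loss_d


netd = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)  # placeholder discriminator
real = torch.rand(2, 3, 16, 16)
fake = torch.rand(2, 3, 16, 16, requires_grad=True)
loss_d = backward_d_basic(netd, real, fake)
```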
- do_iteration(batch_a: torch.Tensor, batch_b: torch.Tensor)[source]¶
Calculate losses, gradients, and update network weights; called in every training iteration.
- _abc_impl = <_abc_data object>¶
simulation.utils.machine_learning.cycle_gan.models.cycle_gan_stats module¶
Classes:
CycleGANStats
- class CycleGANStats(real_a: torch.Tensor = None, real_b: torch.Tensor = None, fake_a: torch.Tensor = None, fake_b: torch.Tensor = None, rec_a: torch.Tensor = None, rec_b: torch.Tensor = None, idt_a: torch.Tensor = None, idt_b: torch.Tensor = None, loss_g_a_to_b: float = None, loss_g_b_to_a: float = None, loss_idt_a: float = None, loss_idt_b: float = None, loss_cycle_a: float = None, loss_cycle_b: float = None, loss_d_a: float = None, loss_d_b: float = None, w_distance_a: float = None, w_distance_b: float = None)[source]¶
Bases:
object
Attributes:
Methods:
- real_a: torch.Tensor = None¶
- real_b: torch.Tensor = None¶
- fake_a: torch.Tensor = None¶
- fake_b: torch.Tensor = None¶
- rec_a: torch.Tensor = None¶
- rec_b: torch.Tensor = None¶
- idt_a: torch.Tensor = None¶
- idt_b: torch.Tensor = None¶
- loss_g_a_to_b: float = None¶
- loss_g_b_to_a: float = None¶
- loss_idt_a: float = None¶
- loss_idt_b: float = None¶
- loss_cycle_a: float = None¶
- loss_cycle_b: float = None¶
- loss_d_a: float = None¶
- loss_d_b: float = None¶
- w_distance_a: float = None¶
- w_distance_b: float = None¶
simulation.utils.machine_learning.cycle_gan.models.discriminator module¶
Functions:
create_discriminator: Create a discriminator.
- create_discriminator(input_nc: int, ndf: int, netd: str, n_layers_d: int = 3, norm: str = 'batch', use_sigmoid: bool = False) → torch.nn.modules.module.Module[source]¶
Create a discriminator.
Returns a discriminator
- Our current implementation provides three types of discriminators:
[basic]: ‘PatchGAN’ classifier described in the original pix2pix paper. It classifies whether 70×70 overlapping patches are real or fake. Such a patch-level discriminator architecture has fewer parameters than a full-image discriminator and can work on arbitrarily-sized images in a fully convolutional fashion.
[n_layers]: With this mode, you can specify the number of conv layers in the discriminator with the parameter <n_layers_d> (default=3, as used in [basic] (PatchGAN)).
[no_patch]: Selects the discriminator from the no_patch_discriminator module (NoPatchDiscriminator).
The discriminator uses Leaky ReLU for non-linearity.
- Parameters
input_nc (int) – # of input image channels: 3 for RGB and 1 for grayscale
ndf (int) – # of discriminator filters in the first conv layer
netd (str) – specify discriminator architecture [basic | n_layers | no_patch]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the number of layers in the discriminator.
n_layers_d (int) – number of layers in the discriminator network
norm (str) – instance normalization or batch normalization [instance | batch | none]
use_sigmoid (bool) – Use sigmoid activation at the end of discriminator network
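The 70×70 patch size can be reproduced with a short receptive-field calculation. The sketch assumes the layer layout of the standard pix2pix PatchGAN (4×4 kernels, `n_layers` stride-2 convolutions followed by two stride-1 convolutions); the discriminator built by this package may be laid out differently.

```python
def patchgan_receptive_field(n_layers: int = 3, kernel_size: int = 4) -> int:
    """Receptive field of one output unit, assuming the standard PatchGAN layout."""
    # n_layers stride-2 convolutions, then two stride-1 convolutions
    strides = [2] * n_layers + [1, 1]
    receptive_field = 1
    # Walk from the output back to the input, growing the field layer by layer.
    for stride in reversed(strides):
        receptive_field = (receptive_field - 1) * stride + kernel_size
    return receptive_field


basic_patch = patchgan_receptive_field(n_layers=3)  # the [basic] configuration
```

With `n_layers=3` the function returns 70, which matches the patch size quoted for the [basic] discriminator.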
simulation.utils.machine_learning.cycle_gan.models.generator module¶
Functions:
create_generator: Create a generator.
- create_generator(input_nc: int, output_nc: int, ngf: int, netg: str, norm: str = 'batch', use_dropout: bool = False, activation: torch.nn.modules.module.Module = Tanh(), conv_layers_in_block: int = 2, dilations: Optional[List[int]] = None) → torch.nn.modules.module.Module[source]¶
Create a generator.
Returns a generator
Our current implementation provides two types of generators.
- U-Net:
[unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images). The original U-Net paper: https://arxiv.org/abs/1505.04597
- Resnet-based generator:
[resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks). A Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. We adapt Torch code from Justin Johnson’s neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
It uses ReLU for non-linearity.
- Parameters
input_nc (int) – # of input image channels: 3 for RGB and 1 for grayscale
output_nc (int) – # of output image channels: 3 for RGB and 1 for grayscale
ngf (int) – # of gen filters in the last conv layer
netg (str) – specify generator architecture [resnet_<ANY_INTEGER>blocks | unet_256 | unet_128]
norm (str) – instance normalization or batch normalization [instance | batch | none]
use_dropout (bool) – enable or disable dropout
activation (nn.Module) – Choose which activation to use.
conv_layers_in_block (int) – specify number of convolution layers per resnet block
dilations – dilation for individual conv layers in every resnet block
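Since `netg` accepts `resnet_<ANY_INTEGER>blocks`, the block count has to be read from the string. The helper below is a hypothetical sketch of one way to do that; `parse_resnet_blocks` is not a function of this package.

```python
import re
from typing import Optional


def parse_resnet_blocks(netg: str) -> Optional[int]:
    """Return the number of Resnet blocks encoded in `netg`, or None for other names."""
    match = re.fullmatch(r"resnet_(\d+)blocks", netg)
    return int(match.group(1)) if match else None


nine_blocks = parse_resnet_blocks("resnet_9blocks")
not_a_resnet = parse_resnet_blocks("unet_256")
```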
simulation.utils.machine_learning.cycle_gan.models.n_layer_discriminator module¶
Classes:
NLayerDiscriminator: Defines a PatchGAN discriminator.
- class NLayerDiscriminator(input_nc: int, ndf: int = 64, n_layers: int = 3, norm_layer: torch.nn.modules.module.Module = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, use_sigmoid: bool = True, is_quadratic: bool = True)[source]¶
Bases:
torch.nn.modules.module.Module
Defines a PatchGAN discriminator.
Methods:
forward(input): Standard forward.
Attributes:
- forward(input: torch.Tensor) → torch.Tensor[source]¶
Standard forward.
- Parameters
input (Tensor) – the input tensor
- training: bool¶
- _is_full_backward_hook: Optional[bool]¶
simulation.utils.machine_learning.cycle_gan.models.no_patch_discriminator module¶
Classes:
NoPatchDiscriminator
- class NoPatchDiscriminator(input_nc: int, norm_layer: torch.nn.modules.module.Module = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, n_layers_d: int = 4, use_sigmoid: bool = True)[source]¶
Bases:
torch.nn.modules.module.Module
Methods:
forward(x): Forward the input through the network and apply average pooling.
Attributes:
- forward(x: torch.Tensor) → torch.Tensor[source]¶
Forward the input through the network and apply average pooling.
- Parameters
x (torch.Tensor) – the input tensor
- training: bool¶
- _is_full_backward_hook: Optional[bool]¶
simulation.utils.machine_learning.cycle_gan.models.wcycle_gan module¶
Classes:
WassersteinCycleGANModel: This class implements the CycleGAN model, for learning image-to-image translation without paired data.
- class WassersteinCycleGANModel(netg_a_to_b: torch.nn.modules.module.Module, netg_b_to_a: torch.nn.modules.module.Module, netd_a: Optional[torch.nn.modules.module.Module] = None, netd_b: Optional[torch.nn.modules.module.Module] = None, is_train: bool = True, beta1: float = 0.5, lr: float = 0.0002, lr_policy: str = 'linear', lambda_idt_a: int = 10, lambda_idt_b: int = 10, lambda_cycle: float = 0.5, optimizer_type: str = 'rms_prop', is_l1: bool = False, wgan_n_critic: int = 5, wgan_initial_n_critic: int = 5, wgan_clip_lower=-0.01, wgan_clip_upper=0.01)[source]¶
Bases:
simulation.utils.machine_learning.cycle_gan.models.base_model.BaseModel
This class implements the CycleGAN model, for learning image-to-image translation without paired data.
By default, it uses a ‘--netg resnet_9blocks’ ResNet generator, a ‘--netd basic’ discriminator (PatchGAN introduced by pix2pix), and a least-squares GAN objective (‘--gan_mode lsgan’).
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
Methods:
update_critic_a(batch_a, batch_b[, clip_bounds])
update_critic_b(batch_a, batch_b[, clip_bounds])
update_generators(batch_a, batch_b)
pre_training(critic_batches)
do_iteration(batch_a, batch_b, critic_batches)
Attributes:
- update_critic_a(batch_a: torch.Tensor, batch_b: torch.Tensor, clip_bounds: Optional[Tuple[float, float]] = None)[source]¶
- update_critic_b(batch_a: torch.Tensor, batch_b: torch.Tensor, clip_bounds: Optional[Tuple[float, float]] = None)[source]¶
- do_iteration(batch_a: torch.Tensor, batch_b: torch.Tensor, critic_batches: List[Tuple[torch.Tensor, torch.Tensor]])[source]¶
- _abc_impl = <_abc_data object>¶
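The critic update methods above are undocumented. The sketch below shows a generic Wasserstein critic step with weight clipping, using the default clip bounds and the RMSprop optimizer suggested by the constructor signature. The linear critic, the random batches, and the exact loss form are assumptions for illustration, not the implementation of `update_critic_a`.

```python
import torch
from torch import nn

critic = nn.Linear(4, 1)  # placeholder for a discriminator acting as critic
optimizer = torch.optim.RMSprop(critic.parameters(), lr=0.0002)
clip_lower, clip_upper = -0.01, 0.01  # defaults of wgan_clip_lower / wgan_clip_upper

real = torch.rand(8, 4)  # stand-in for real images of one domain
fake = torch.rand(8, 4)  # stand-in for generator output

optimizer.zero_grad()
# Estimated Wasserstein distance: mean critic score on real minus on fake.
w_distance = critic(real).mean() - critic(fake.detach()).mean()
# The critic maximises the distance, so the optimizer minimises its negative.
loss_d = -w_distance
loss_d.backward()
optimizer.step()

# Clip every weight into [clip_lower, clip_upper] after the step.
with torch.no_grad():
    for param in critic.parameters():
        param.clamp_(clip_lower, clip_upper)
```

In a typical WGAN setup this step is repeated several times per generator update, which is presumably what `wgan_n_critic` controls.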