simulation.utils.machine_learning.cycle_gan.models.wcycle_gan module

Summary

Classes:

WassersteinCycleGANModel

This class implements a Wasserstein variant of the CycleGAN model for learning image-to-image translation without paired data.

Reference

class WassersteinCycleGANModel(netg_a_to_b: torch.nn.modules.module.Module, netg_b_to_a: torch.nn.modules.module.Module, netd_a: Optional[torch.nn.modules.module.Module] = None, netd_b: Optional[torch.nn.modules.module.Module] = None, is_train: bool = True, beta1: float = 0.5, lr: float = 0.0002, lr_policy: str = 'linear', lambda_idt_a: int = 10, lambda_idt_b: int = 10, lambda_cycle: float = 0.5, optimizer_type: str = 'rms_prop', is_l1: bool = False, wgan_n_critic: int = 5, wgan_initial_n_critic: int = 5, wgan_clip_lower=-0.01, wgan_clip_upper=0.01)

Bases: simulation.utils.machine_learning.cycle_gan.models.base_model.BaseModel

This class implements a Wasserstein variant of the CycleGAN model for learning image-to-image translation without paired data.

By default, it uses a ResNet generator (‘--netg resnet_9blocks’), a basic discriminator (‘--netd basic’, the PatchGAN introduced by pix2pix), and a least-squares GAN objective (‘--gan_mode lsgan’).

CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
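The constructor defaults wgan_clip_lower=-0.01, wgan_clip_upper=0.01 and optimizer_type='rms_prop' are consistent with the weight-clipping formulation of the Wasserstein GAN, in which the critic's weights are clamped to a small interval after every optimizer step. The sketch below shows one such critic update under that assumption. The toy critic and the critic_step helper are hypothetical illustrations, not the module's update_critic_a or update_critic_b, whose exact losses are not documented on this page.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy critic; any nn.Module that returns one score per sample works.
critic = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(0.2), nn.Linear(8, 1))
# RMSprop with lr=0.0002, mirroring the optimizer_type='rms_prop' and lr defaults.
optimizer = torch.optim.RMSprop(critic.parameters(), lr=0.0002)


def critic_step(critic, optimizer, real, fake, clip_bounds=(-0.01, 0.01)):
    """One WGAN critic update followed by weight clipping (assumed behavior)."""
    optimizer.zero_grad()
    # The critic maximizes E[D(real)] - E[D(fake)], so we minimize the negation.
    # fake is detached so that no gradient flows back into a generator.
    loss = critic(fake.detach()).mean() - critic(real).mean()
    loss.backward()
    optimizer.step()
    if clip_bounds is not None:
        lower, upper = clip_bounds
        # Clamp every weight in place to enforce the clipping constraint.
        with torch.no_grad():
            for param in critic.parameters():
                param.clamp_(lower, upper)
    return loss.item()


real = torch.randn(16, 4)
fake = torch.randn(16, 4)
loss = critic_step(critic, optimizer, real, fake)
```

Passing clip_bounds=None skips the clamp, which loosely mirrors the Optional clip_bounds argument of the critic-update methods; whether the module clips on every call is an assumption here.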

update_critic_a(batch_a: torch.Tensor, batch_b: torch.Tensor, clip_bounds: Optional[Tuple[float, float]] = None)
update_critic_b(batch_a: torch.Tensor, batch_b: torch.Tensor, clip_bounds: Optional[Tuple[float, float]] = None)
update_generators(batch_a: torch.Tensor, batch_b: torch.Tensor)
pre_training(critic_batches)
do_iteration(batch_a: torch.Tensor, batch_b: torch.Tensor, critic_batches: List[Tuple[torch.Tensor, torch.Tensor]])
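The do_iteration signature takes a main (batch_a, batch_b) pair plus a list of critic_batches, and the constructor defaults to wgan_n_critic=5. A plausible reading is the usual WGAN schedule: several critic updates, one per pair in critic_batches, followed by a single generator update. The self-contained sketch below illustrates that schedule with hypothetical toy linear networks and helper functions named after the methods above; it is not the module's implementation. LAMBDA_CYCLE is an illustrative weight, and the identity terms weighted by lambda_idt_a and lambda_idt_b are omitted for brevity.

```python
from typing import List, Tuple

import torch
from torch import nn

torch.manual_seed(0)
DIM = 4

# Hypothetical toy networks standing in for netg_a_to_b, netg_b_to_a, netd_a and netd_b.
g_a_to_b, g_b_to_a = nn.Linear(DIM, DIM), nn.Linear(DIM, DIM)
d_a, d_b = nn.Linear(DIM, 1), nn.Linear(DIM, 1)
critic_params = list(d_a.parameters()) + list(d_b.parameters())
gen_params = list(g_a_to_b.parameters()) + list(g_b_to_a.parameters())
opt_d = torch.optim.RMSprop(critic_params, lr=0.0002)
opt_g = torch.optim.RMSprop(gen_params, lr=0.0002)
CLIP_BOUNDS = (-0.01, 0.01)
LAMBDA_CYCLE = 10.0  # illustrative weight, not taken from the module


def update_critics(batch_a: torch.Tensor, batch_b: torch.Tensor) -> None:
    opt_d.zero_grad()
    # Generated samples are detached: only the critics are trained here.
    fake_a, fake_b = g_b_to_a(batch_b).detach(), g_a_to_b(batch_a).detach()
    loss = (d_a(fake_a).mean() - d_a(batch_a).mean()) + (d_b(fake_b).mean() - d_b(batch_b).mean())
    loss.backward()
    opt_d.step()
    with torch.no_grad():  # weight clipping after every critic step
        for p in critic_params:
            p.clamp_(*CLIP_BOUNDS)


def update_generators(batch_a: torch.Tensor, batch_b: torch.Tensor) -> float:
    opt_g.zero_grad()
    fake_b, fake_a = g_a_to_b(batch_a), g_b_to_a(batch_b)
    # Wasserstein adversarial term: the generators try to raise the critic scores.
    adversarial = -d_b(fake_b).mean() - d_a(fake_a).mean()
    # Cycle consistency (L1): A -> B -> A and B -> A -> B should reconstruct the inputs.
    cycle = (g_b_to_a(fake_b) - batch_a).abs().mean() + (g_a_to_b(fake_a) - batch_b).abs().mean()
    loss = adversarial + LAMBDA_CYCLE * cycle
    loss.backward()
    opt_g.step()  # only generator weights are stepped; critic grads are reset next critic step
    return loss.item()


def do_iteration(
    batch_a: torch.Tensor,
    batch_b: torch.Tensor,
    critic_batches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> float:
    # Assumed schedule: one critic update per (a, b) pair, then one generator update.
    for critic_a, critic_b in critic_batches:
        update_critics(critic_a, critic_b)
    return update_generators(batch_a, batch_b)


n_critic = 5  # matches the wgan_n_critic default
critic_batches = [(torch.randn(8, DIM), torch.randn(8, DIM)) for _ in range(n_critic)]
loss = do_iteration(torch.randn(8, DIM), torch.randn(8, DIM), critic_batches)
```

Training the critics several times per generator step keeps the critic close to optimal, so its scores give a more useful estimate of the Wasserstein distance for the generator update; the separate wgan_initial_n_critic default and the pre_training method suggest the module may treat early iterations differently, which this sketch does not model.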