simulation.utils.machine_learning.cycle_gan.models.wcycle_gan module¶
Summary¶
Classes:

WassersteinCycleGANModel | This class implements the CycleGAN model, for learning image-to-image translation without paired data.
Reference¶
class WassersteinCycleGANModel(netg_a_to_b: torch.nn.modules.module.Module, netg_b_to_a: torch.nn.modules.module.Module, netd_a: Optional[torch.nn.modules.module.Module] = None, netd_b: Optional[torch.nn.modules.module.Module] = None, is_train: bool = True, beta1: float = 0.5, lr: float = 0.0002, lr_policy: str = 'linear', lambda_idt_a: int = 10, lambda_idt_b: int = 10, lambda_cycle: float = 0.5, optimizer_type: str = 'rms_prop', is_l1: bool = False, wgan_n_critic: int = 5, wgan_initial_n_critic: int = 5, wgan_clip_lower=-0.01, wgan_clip_upper=0.01)[source]¶

Bases: simulation.utils.machine_learning.cycle_gan.models.base_model.BaseModel

This class implements the CycleGAN model, for learning image-to-image translation without paired data.
By default, it uses a '--netg resnet_9blocks' ResNet generator, a '--netd basic' discriminator (PatchGAN introduced by pix2pix), and a least-square GANs objective ('--gan_mode lsgan').
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
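The Wasserstein variant replaces the least-square GAN objective with a critic that estimates the Wasserstein distance between real and generated distributions. The following is a minimal, framework-level sketch of that critic objective, not this module's actual code; the `critic` network and feature shapes are hypothetical placeholders for the `netd_a` / `netd_b` modules passed to the constructor.

```python
import torch

# Hypothetical stand-in critic (the real model uses netd_a / netd_b).
critic = torch.nn.Linear(8, 1)

real = torch.randn(4, 8)  # batch drawn from the "real" domain
fake = torch.randn(4, 8)  # batch produced by a generator

# Wasserstein critic objective: maximize E[D(real)] - E[D(fake)],
# implemented as minimizing the negated difference.
critic_loss = -(critic(real).mean() - critic(fake).mean())
critic_loss.backward()
```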
update_critic_a(batch_a: torch.Tensor, batch_b: torch.Tensor, clip_bounds: Optional[Tuple[float, float]] = None)[source]¶
update_critic_b(batch_a: torch.Tensor, batch_b: torch.Tensor, clip_bounds: Optional[Tuple[float, float]] = None)[source]¶
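The `clip_bounds` parameter of both critic updates, together with the constructor defaults `wgan_clip_lower=-0.01` and `wgan_clip_upper=0.01`, matches the standard WGAN weight-clipping step: after each critic update, every critic parameter is clamped into the given interval to crudely enforce a Lipschitz constraint. A hedged sketch of that step (the `critic` module here is a placeholder, not the project's discriminator):

```python
import torch

# Hypothetical critic; wgan_clip_lower / wgan_clip_upper default to +-0.01.
critic = torch.nn.Linear(8, 1)
clip_bounds = (-0.01, 0.01)

# Clamp every critic parameter into [lower, upper] after an update.
with torch.no_grad():
    for p in critic.parameters():
        p.clamp_(*clip_bounds)
```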
do_iteration(batch_a: torch.Tensor, batch_b: torch.Tensor, critic_batches: List[Tuple[torch.Tensor, torch.Tensor]])[source]¶
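The `critic_batches` argument and the constructor's `wgan_n_critic` / `wgan_initial_n_critic` parameters suggest the usual WGAN schedule: several critic updates (with weight clipping) per generator update. The sketch below shows that schedule in generic PyTorch on toy linear networks; it is an illustration of the technique under assumed shapes, not the signature or body of `do_iteration` itself.

```python
import torch

torch.manual_seed(0)
critic = torch.nn.Linear(8, 1)
generator = torch.nn.Linear(8, 8)
# optimizer_type='rms_prop' and lr=0.0002 mirror the constructor defaults.
opt_d = torch.optim.RMSprop(critic.parameters(), lr=2e-4)
opt_g = torch.optim.RMSprop(generator.parameters(), lr=2e-4)

n_critic = 5                 # cf. wgan_n_critic
clip_bounds = (-0.01, 0.01)  # cf. wgan_clip_lower / wgan_clip_upper
critic_batches = [(torch.randn(4, 8), torch.randn(4, 8)) for _ in range(n_critic)]

# 1) Several critic updates per iteration, each followed by weight clipping.
for real, noise in critic_batches:
    opt_d.zero_grad()
    fake = generator(noise).detach()  # no generator gradients here
    d_loss = -(critic(real).mean() - critic(fake).mean())
    d_loss.backward()
    opt_d.step()
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(*clip_bounds)

# 2) One generator update: push the critic's score of fakes upward.
opt_g.zero_grad()
g_loss = -critic(generator(torch.randn(4, 8))).mean()
g_loss.backward()
opt_g.step()
```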