simulation.utils.machine_learning.cycle_gan.models.cycle_gan_model module

Summary

Classes:

CycleGANModel

This class implements the CycleGAN model, for learning image-to-image translation without paired data.

Reference

class CycleGANModel(netg_a_to_b: torch.nn.modules.module.Module, netg_b_to_a: torch.nn.modules.module.Module, netd_a: Optional[torch.nn.modules.module.Module] = None, netd_b: Optional[torch.nn.modules.module.Module] = None, is_train: bool = True, cycle_noise_stddev: int = 0, pool_size: int = 50, beta1: float = 0.5, lr: float = 0.0002, lr_policy: str = 'linear', lambda_idt_a: int = 10, lambda_idt_b: int = 10, lambda_cycle: float = 0.5, optimizer_type: str = 'adam', is_l1: bool = False)

Bases: simulation.utils.machine_learning.cycle_gan.models.base_model.BaseModel

This class implements the CycleGAN model, for learning image-to-image translation without paired data.

CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
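The key idea of the CycleGAN paper is the cycle-consistency loss: translating an image to the other domain and back should reproduce it, i.e. L_cyc = E[|G_BA(G_AB(a)) - a|_1] + E[|G_AB(G_BA(b)) - b|_1]. The following is a minimal, dependency-free sketch of that formula. It represents images as flat lists of floats (an assumption made only to keep the example self-contained; the class itself works on torch.Tensor), and it does not claim to reproduce how this class weights the term with lambda_cycle or the identity weights.

```python
from typing import Callable, List

Image = List[float]  # hypothetical stand-in for a flattened image tensor
Generator = Callable[[Image], Image]


def l1(x: Image, y: Image) -> float:
    """Mean absolute error between two equally sized, non-empty images."""
    if len(x) != len(y) or not x:
        raise ValueError("images must be non-empty and the same size")
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)


def cycle_consistency_loss(
    real_a: Image, real_b: Image, g_a_to_b: Generator, g_b_to_a: Generator
) -> float:
    """Forward cycle (A -> B -> A) plus backward cycle (B -> A -> B), unweighted."""
    rec_a = g_b_to_a(g_a_to_b(real_a))  # reconstruction of the A image
    rec_b = g_a_to_b(g_b_to_a(real_b))  # reconstruction of the B image
    return l1(rec_a, real_a) + l1(rec_b, real_b)
```

With a pair of generators that exactly invert each other (for example, one adds 0.1 to every pixel and the other subtracts 0.1), the loss is zero up to floating-point error; if the second generator is replaced by the identity, each cycle is off by 0.1 per pixel and the loss is about 0.2.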

backward_d_basic(netd: torch.nn.modules.module.Module, real: torch.Tensor, fake: torch.Tensor) → torch.Tensor

Calculate GAN loss for the discriminator.

It also calls loss_d.backward() to compute the gradients.

Returns

Discriminator loss.

Parameters
  • netd (nn.Module) – the discriminator network

  • real (torch.Tensor) – the real image

  • fake (torch.Tensor) – the fake image
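This page does not say which GAN objective backward_d_basic uses. A common choice, and the default in the reference PyTorch CycleGAN implementation, is the least-squares (LSGAN) loss: the discriminator is pushed toward 1 on real images and 0 on fakes, and the two terms are averaged. The sketch below assumes that objective and works on plain lists of discriminator outputs instead of tensors. In the reference implementation the fake batch is also detached (fake.detach()) before it is passed to the discriminator, so that loss_d.backward() does not propagate gradients into the generator.

```python
from typing import Sequence


def mse_to_target(preds: Sequence[float], target: float) -> float:
    """Mean squared error between discriminator outputs and a constant label."""
    if not preds:
        raise ValueError("need at least one prediction")
    return sum((p - target) ** 2 for p in preds) / len(preds)


def discriminator_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """LSGAN discriminator loss (assumed objective), averaged over real and fake terms."""
    loss_real = mse_to_target(d_real, 1.0)  # real images should be scored as 1
    loss_fake = mse_to_target(d_fake, 0.0)  # generated images should be scored as 0
    return 0.5 * (loss_real + loss_fake)
```

A perfect discriminator (outputs 1 on real, 0 on fake) gets a loss of 0, while one that outputs 0.5 for everything gets 0.5 * (0.25 + 0.25) = 0.25.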

backward_d_a(real_a, fake_a) → float

Calculate GAN loss for discriminator D_A.

backward_d_b(real_b, fake_b) → float

Calculate GAN loss for discriminator D_B.
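The constructor's pool_size parameter (default 50) suggests that, as in the reference CycleGAN implementation, the fake images handed to the discriminators come from a history buffer of previously generated images rather than only the latest batch. That is an assumption, not something this page documents. The sketch below is a hypothetical ImagePool that follows the reference behaviour: fill the pool first, then with probability 0.5 return a stored older image (and store the new one in its place), otherwise return the new image.

```python
import random
from typing import Generic, List, Optional, TypeVar

T = TypeVar("T")


class ImagePool(Generic[T]):
    """Hypothetical history buffer of generated images."""

    def __init__(self, pool_size: int = 50, rng: Optional[random.Random] = None) -> None:
        self.pool_size = pool_size
        self.images: List[T] = []
        self.rng = rng or random.Random()

    def query(self, batch: List[T]) -> List[T]:
        """Return a batch of the same length, mixing new and stored images."""
        if self.pool_size == 0:
            return list(batch)  # pooling disabled: pass the batch through
        out: List[T] = []
        for image in batch:
            if len(self.images) < self.pool_size:
                # Pool not full yet: store the new image and return it unchanged.
                self.images.append(image)
                out.append(image)
            elif self.rng.random() > 0.5:
                # Return an older image and keep the new one in its slot.
                idx = self.rng.randrange(self.pool_size)
                out.append(self.images[idx])
                self.images[idx] = image
            else:
                # Return the new image without storing it.
                out.append(image)
        return out
```

Showing the discriminator older fakes as well as the newest ones is meant to reduce oscillation between the two networks; with pool_size=0 the pool is a no-op.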

do_iteration(batch_a: torch.Tensor, batch_b: torch.Tensor)

Calculate losses, gradients, and update network weights; called in every training iteration.
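This page does not list the individual steps of do_iteration. In the reference PyTorch CycleGAN, one training iteration first updates both generators while the discriminators are frozen, and then updates both discriminators. The sketch below encodes that order as hypothetical step names and runs a caller-supplied handler for each step. It shows only the sequencing and assumes this class follows the reference order.

```python
from typing import Callable, List, Mapping

# Hypothetical step names mirroring the reference CycleGAN order, not this module's API.
STEP_ORDER = (
    "forward",       # both generators produce fakes and cycle reconstructions
    "freeze_d",      # discriminators: requires_grad = False while updating G
    "zero_grad_g",
    "backward_g",    # generator losses (adversarial, cycle, identity)
    "step_g",
    "unfreeze_d",    # re-enable discriminator gradients
    "zero_grad_d",
    "backward_d_a",  # discriminator loss for D_A
    "backward_d_b",  # discriminator loss for D_B
    "step_d",
)


def do_iteration_sketch(handlers: Mapping[str, Callable[[], None]]) -> List[str]:
    """Call each handler in STEP_ORDER and return the names of the steps that ran."""
    ran: List[str] = []
    for name in STEP_ORDER:
        handlers[name]()
        ran.append(name)
    return ran
```

Freezing the discriminators during the generator update keeps the generator loss from also accumulating gradients into netd_a and netd_b.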
