simulation.utils.machine_learning.cycle_gan.models.cycle_gan_model module
Summary
Classes:
CycleGANModel: This class implements the CycleGAN model for learning image-to-image translation without paired data.
Reference
class CycleGANModel(netg_a_to_b: torch.nn.modules.module.Module, netg_b_to_a: torch.nn.modules.module.Module, netd_a: Optional[torch.nn.modules.module.Module] = None, netd_b: Optional[torch.nn.modules.module.Module] = None, is_train: bool = True, cycle_noise_stddev: int = 0, pool_size: int = 50, beta1: float = 0.5, lr: float = 0.0002, lr_policy: str = 'linear', lambda_idt_a: int = 10, lambda_idt_b: int = 10, lambda_cycle: float = 0.5, optimizer_type: str = 'adam', is_l1: bool = False)
Bases: simulation.utils.machine_learning.cycle_gan.models.base_model.BaseModel
This class implements the CycleGAN model for learning image-to-image translation without paired data.
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
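Unpaired translation rests on the cycle-consistency term from the CycleGAN paper. Translating A → B → A, and likewise B → A → B, should reproduce the input, and any difference is penalized with an L1 distance. The sketch below is minimal and framework-free. It makes the following simplifying choices:

- Images are flat lists of floats.
- The two generators are plain functions standing in for `netg_a_to_b` and `netg_b_to_a`.
- The helper names are hypothetical.
- Scaling the sum by `lambda_cycle` is an assumption, because this page does not document how the module weights its loss terms.

```python
from typing import Callable, List

# Hypothetical stand-in for an image tensor: a flat list of pixel values.
Image = List[float]


def l1_distance(x: Image, y: Image) -> float:
    """Mean absolute error between two equally sized 'images'."""
    if len(x) != len(y):
        raise ValueError("images must have the same number of pixels")
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)


def cycle_consistency_loss(
    real_a: Image,
    real_b: Image,
    g_a_to_b: Callable[[Image], Image],
    g_b_to_a: Callable[[Image], Image],
    lambda_cycle: float,
) -> float:
    """L1 reconstruction error of both cycles, scaled by lambda_cycle (assumed weighting)."""
    rec_a = g_b_to_a(g_a_to_b(real_a))  # A -> B -> A
    rec_b = g_a_to_b(g_b_to_a(real_b))  # B -> A -> B
    return lambda_cycle * (l1_distance(rec_a, real_a) + l1_distance(rec_b, real_b))
```

Exact inverse generators give a loss of zero, for example `lambda x: [p + 1.0 for p in x]` paired with `lambda x: [p - 1.0 for p in x]`. Any mismatch between the two mappings makes the loss positive. No paired A/B images are needed, because each cycle compares an image only with its own reconstruction.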
backward_d_basic(netd: torch.nn.modules.module.Module, real: torch.Tensor, fake: torch.Tensor) → torch.Tensor
Calculate the GAN loss for the discriminator.
It also calls loss_d.backward() to compute the gradients.
Returns:
Discriminator loss.
Parameters:
netd (nn.Module) – the discriminator network
real (torch.Tensor) – the real image
fake (torch.Tensor) – the fake image
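The page says only that `backward_d_basic` computes the discriminator's GAN loss. The framework-free sketch below assumes two details taken from the reference CycleGAN implementation: a least-squares GAN objective, and averaging of the real and fake terms by 0.5. The discriminator is any callable that returns per-sample scores. Plain Python cannot express two PyTorch steps, the gradient step (`loss_d.backward()`) and the detaching of `fake`, so they appear only as comments.

```python
from typing import Callable, Sequence


def mse(preds: Sequence[float], target: float) -> float:
    """Mean squared error of predictions against a constant label."""
    return sum((p - target) ** 2 for p in preds) / len(preds)


def discriminator_loss(
    netd: Callable[[Sequence[float]], Sequence[float]],
    real: Sequence[float],
    fake: Sequence[float],
) -> float:
    """Least-squares GAN loss (assumed): push D(real) toward 1 and D(fake) toward 0."""
    loss_real = mse(netd(real), 1.0)
    # The reference implementation detaches `fake` here so no gradient reaches the generator.
    loss_fake = mse(netd(fake), 0.0)
    # Halving the objective slows the discriminator relative to the generators (CycleGAN paper).
    # In PyTorch, loss_d.backward() would be called on this value before returning it.
    return 0.5 * (loss_real + loss_fake)
```

For example, take an identity "discriminator" that returns its input. Real scores of 1.0 and fake scores of 0.5 give 0.5 * (0 + 0.25) = 0.125.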
do_iteration(batch_a: torch.Tensor, batch_b: torch.Tensor)
Calculate the losses and gradients and update the network weights; called in every training iteration.
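The `pool_size` constructor argument suggests a history buffer. The CycleGAN paper describes one, storing the 50 most recently generated images and updating the discriminators with them rather than only with the latest fakes. This page does not document whether or how `do_iteration` uses such a pool. The sketch below is a hypothetical `ImagePool` modeled on the reference implementation's strategy and is not this module's API. Once the buffer is full, each new fake is either returned directly or, with probability 0.5, swapped with a randomly chosen stored image.

```python
import random
from typing import Generic, List, Optional, TypeVar

T = TypeVar("T")


class ImagePool(Generic[T]):
    """History buffer of generated images (hypothetical helper, not this module's API)."""

    def __init__(self, pool_size: int = 50, rng: Optional[random.Random] = None) -> None:
        self.pool_size = pool_size
        self.images: List[T] = []
        self.rng = rng if rng is not None else random.Random()

    def query(self, images: List[T]) -> List[T]:
        """Return a batch mixing the newest fakes with previously generated ones."""
        if self.pool_size == 0:
            return list(images)  # pool disabled: pass fakes straight through
        result: List[T] = []
        for image in images:
            if len(self.images) < self.pool_size:
                # Fill the buffer first and return the new image unchanged.
                self.images.append(image)
                result.append(image)
            elif self.rng.random() < 0.5:
                # Return an older image and store the new one in its slot.
                idx = self.rng.randrange(self.pool_size)
                result.append(self.images[idx])
                self.images[idx] = image
            else:
                result.append(image)
        return result
```

In a training step built this way, the batch returned by `query` would be passed as `fake` to the discriminator loss. The buffer shows the discriminator older generator outputs, which the paper uses to reduce model oscillation.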