simulation.utils.machine_learning.cycle_gan.models.base_model module¶
Reference¶
class CycleGANNetworks(g_a_to_b: torch.nn.modules.module.Module, g_b_to_a: torch.nn.modules.module.Module, d_a: Optional[torch.nn.modules.module.Module] = None, d_b: Optional[torch.nn.modules.module.Module] = None)[source]¶
Bases: object

Container class for all networks used within the CycleGAN.
The CycleGAN requires images from two domains, a and b, and aims to translate images from one domain to the other.
g_a_to_b: torch.nn.modules.module.Module¶
Generator that transforms images from domain a to domain b.
g_b_to_a: torch.nn.modules.module.Module¶
Generator that transforms images from domain b to domain a.
d_a: torch.nn.modules.module.Module = None¶
Discriminator that decides whether images are real or fake in domain a.
d_b: torch.nn.modules.module.Module = None¶
Discriminator that decides whether images are real or fake in domain b.
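To make the two-domain setup concrete, here is a minimal sketch of the a → b → a translation cycle. The single-convolution generators below are toy stand-ins for illustration only; a real CycleGAN uses deep ResNet-style generators.

```python
import torch
import torch.nn as nn

# Toy stand-ins for g_a_to_b and g_b_to_a (real generators are ResNet-based).
g_a_to_b = nn.Conv2d(3, 3, kernel_size=3, padding=1)
g_b_to_a = nn.Conv2d(3, 3, kernel_size=3, padding=1)

real_a = torch.randn(1, 3, 32, 32)  # image batch from domain a
fake_b = g_a_to_b(real_a)           # translated into domain b
rec_a = g_b_to_a(fake_b)            # cycled back into domain a

assert rec_a.shape == real_a.shape  # the cycle preserves image shape
```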
save(prefix_path: str) → None[source]¶
Save all the networks to disk.
- Parameters
prefix_path (str) – the path which gets extended by the model name
load(prefix_path: str, device: torch.device)[source]¶
Load all the networks from disk.
- Parameters
prefix_path (str) – the path which is extended by the model name
device (torch.device) – the device onto which the networks are loaded
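The save/load pair above suggests prefix-based file naming. The following is a sketch of how such persistence might look; the helper names and the `<prefix><name>.pth` file pattern are assumptions for illustration, not this module's actual implementation.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_networks(networks: dict, prefix_path: str) -> None:
    # Assumed naming scheme: each network goes to <prefix_path><name>.pth.
    for name, net in networks.items():
        torch.save(net.state_dict(), f"{prefix_path}{name}.pth")


def load_networks(networks: dict, prefix_path: str, device: torch.device) -> None:
    # Load each state dict directly onto the requested device.
    for name, net in networks.items():
        state = torch.load(f"{prefix_path}{name}.pth", map_location=device)
        net.load_state_dict(state)


nets = {"g_a_to_b": nn.Linear(4, 4), "g_b_to_a": nn.Linear(4, 4)}
with tempfile.TemporaryDirectory() as d:
    prefix = os.path.join(d, "checkpoint_")
    save_networks(nets, prefix)
    original = nets["g_a_to_b"].weight.detach().clone()
    nn.init.zeros_(nets["g_a_to_b"].weight)  # simulate diverged weights
    load_networks(nets, prefix, torch.device("cpu"))
    restored = torch.equal(nets["g_a_to_b"].weight, original)
```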
__patch_instance_norm_state_dict(module: torch.nn.modules.module.Module, keys: List[str], i: int = 0) → None¶
Fix InstanceNorm checkpoint incompatibility (checkpoints saved with PyTorch versions prior to 0.4).
- Parameters
state_dict (dict) – a dict containing parameters from the saved model files
module (nn.Module) – the network loaded from a file
keys (List[str]) – the keys inside the save file
i (int) – current index in network structure
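The incompatibility arises because pre-0.4 checkpoints stored running_mean/running_var entries for InstanceNorm layers even when running statistics were not tracked, and newer PyTorch versions reject those keys when loading. A minimal, framework-free sketch of the patching idea (the function name and the prefix argument are assumptions for illustration):

```python
def patch_instance_norm_state_dict(state_dict: dict, prefix: str) -> None:
    """Drop running statistics that pre-0.4 checkpoints stored for an
    InstanceNorm layer located at the dotted path `prefix`."""
    for stat in ("running_mean", "running_var"):
        # pop with a default so missing keys are silently ignored
        state_dict.pop(f"{prefix}.{stat}", None)


# "model.1" stands for an InstanceNorm layer inside a saved checkpoint.
sd = {
    "model.0.weight": 1,
    "model.1.running_mean": 0,
    "model.1.running_var": 1,
}
patch_instance_norm_state_dict(sd, "model.1")
# sd now contains only the "model.0.weight" entry
```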
class BaseModel(netg_a_to_b, netg_b_to_a, netd_a, netd_b, is_train, lambda_cycle, lambda_idt_a, lambda_idt_b, is_l1, optimizer_type, lr_policy, beta1: float = 0.5, lr: float = 0.0002, cycle_noise_stddev: float = 0)[source]¶
Bases: abc.ABC, simulation.utils.basics.init_options.InitOptions
forward(real_a, real_b) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor][source]¶
Run the forward pass; called by both optimize_parameters() and test().
test(batch_a, batch_b) → simulation.utils.machine_learning.cycle_gan.models.cycle_gan_stats.CycleGANStats[source]¶
Forward function used at test time.

This function wraps forward() in torch.no_grad() so that intermediate results are not stored for backpropagation. It also calls compute_visuals() to produce additional visualization results.
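The effect of wrapping the forward pass in no_grad() can be seen with a toy module: autograd records no graph inside the context, so the output carries no gradient requirement and no intermediate buffers are retained.

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

with torch.no_grad():  # disables autograd bookkeeping for this block
    y = net(x)

assert not y.requires_grad  # no graph was built, nothing to backprop through
```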
create_schedulers(lr_policy: str = 'linear', lr_decay_iters: int = 50, lr_step_factor: float = 0.1, n_epochs: int = 100)[source]¶
Create learning rate schedulers.
- Parameters
lr_policy – learning rate policy: [linear | step | plateau | cosine]
lr_decay_iters – for the step policy: multiply the learning rate by lr_step_factor every lr_decay_iters iterations
lr_step_factor – factor by which the learning rate is multiplied at each decay step
n_epochs – number of epochs with the initial learning rate
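A sketch of how the four listed policies could map onto torch.optim.lr_scheduler classes. The exact decay formulas below are illustrative assumptions, not this module's implementation.

```python
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    LambdaLR,
    ReduceLROnPlateau,
    StepLR,
)


def make_scheduler(optimizer, lr_policy, lr_decay_iters=50,
                   lr_step_factor=0.1, n_epochs=100):
    if lr_policy == "linear":
        # Keep the initial lr for n_epochs, then decay linearly (assumed form).
        decay = lambda epoch: 1.0 - max(0, epoch - n_epochs) / float(n_epochs + 1)
        return LambdaLR(optimizer, lr_lambda=decay)
    if lr_policy == "step":
        # Multiply lr by lr_step_factor every lr_decay_iters iterations.
        return StepLR(optimizer, step_size=lr_decay_iters, gamma=lr_step_factor)
    if lr_policy == "plateau":
        # Reduce lr when a monitored metric stops improving.
        return ReduceLROnPlateau(optimizer, mode="min",
                                 factor=lr_step_factor, patience=5)
    if lr_policy == "cosine":
        # Anneal lr to zero following a cosine curve over n_epochs.
        return CosineAnnealingLR(optimizer, T_max=n_epochs)
    raise ValueError(f"unknown lr policy: {lr_policy}")


opt = SGD(nn.Linear(2, 2).parameters(), lr=0.0002)
sched = make_scheduler(opt, "step")
```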