simulation.utils.machine_learning.cycle_gan.models.base_model module

Summary

Classes:

BaseModel

CycleGANNetworks

Container class for all networks used within the CycleGAN.

Reference

class CycleGANNetworks(g_a_to_b: torch.nn.modules.module.Module, g_b_to_a: torch.nn.modules.module.Module, d_a: Optional[torch.nn.modules.module.Module] = None, d_b: Optional[torch.nn.modules.module.Module] = None)

Bases: object

Container class for all networks used within the CycleGAN.

The CycleGAN works with images from two domains, a and b, and aims to translate images from one domain to the other.

g_a_to_b: torch.nn.modules.module.Module

Generator that transforms images from domain a to domain b.

g_b_to_a: torch.nn.modules.module.Module

Generator that transforms images from domain b to domain a.

d_a: torch.nn.modules.module.Module = None

Discriminator that decides whether images in domain a are real or fake.

d_b: torch.nn.modules.module.Module = None

Discriminator that decides whether images in domain b are real or fake.
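CycleGANNetworks appears to be a plain dataclass-style container. The minimal sketch below assumes exactly that; `NetworksSketch` is a hypothetical name, and the tiny `nn.Conv2d` layers only stand in for real generators. It shows the a → b → a round trip that the two generators enable, while both discriminators default to `None`.

```python
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn


@dataclass
class NetworksSketch:
    """Hypothetical stand-in for CycleGANNetworks (assumes a dataclass-like container)."""

    g_a_to_b: nn.Module  # generator: domain a -> domain b
    g_b_to_a: nn.Module  # generator: domain b -> domain a
    d_a: Optional[nn.Module] = None  # discriminator for domain a (optional, e.g. at test time)
    d_b: Optional[nn.Module] = None  # discriminator for domain b


# Placeholder 3-channel "generators" that keep the spatial size unchanged.
nets = NetworksSketch(
    g_a_to_b=nn.Conv2d(3, 3, kernel_size=3, padding=1),
    g_b_to_a=nn.Conv2d(3, 3, kernel_size=3, padding=1),
)

real_a = torch.randn(1, 3, 8, 8)
fake_b = nets.g_a_to_b(real_a)  # translate a -> b
rec_a = nets.g_b_to_a(fake_b)  # translate back b -> a (the "cycle")
```

Keeping the discriminators optional lets the same container serve inference, where only the generators are needed.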

save(prefix_path: str) → None

Save all networks to disk.

Parameters

prefix_path (str) – the path prefix, which is extended by the model name

load(prefix_path: str, device: torch.device)

Load all networks from disk.

Parameters
  • prefix_path (str) – the path which is extended by the model name

  • device (torch.device) – the device on which the networks are loaded
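A save/load pair of this kind is commonly built on `torch.save` and `torch.load` of each network's `state_dict`. The sketch below assumes that approach; the file naming `"<prefix_path><name>.pth"` and the dict-of-networks argument are hypothetical, not taken from this module. Missing networks (for example, discriminators at test time) are skipped.

```python
import os
import tempfile
from typing import Dict, Optional

import torch
from torch import nn


def save_networks(nets: Dict[str, Optional[nn.Module]], prefix_path: str) -> None:
    """Save each network's state_dict to "<prefix_path><name>.pth" (hypothetical naming)."""
    for name, net in nets.items():
        if net is not None:  # discriminators may be missing, e.g. at test time
            torch.save(net.state_dict(), f"{prefix_path}{name}.pth")


def load_networks(
    nets: Dict[str, Optional[nn.Module]], prefix_path: str, device: torch.device
) -> None:
    """Load each network's weights from "<prefix_path><name>.pth" onto `device`."""
    for name, net in nets.items():
        if net is not None:
            state_dict = torch.load(f"{prefix_path}{name}.pth", map_location=device)
            net.load_state_dict(state_dict)
            net.to(device)


original = {"g_a_to_b": nn.Linear(4, 4), "g_b_to_a": nn.Linear(4, 4), "d_a": None}
restored = {"g_a_to_b": nn.Linear(4, 4), "g_b_to_a": nn.Linear(4, 4), "d_a": None}

with tempfile.TemporaryDirectory() as tmp_dir:
    prefix = os.path.join(tmp_dir, "latest_net_")
    save_networks(original, prefix)
    saved_files = sorted(os.listdir(tmp_dir))
    load_networks(restored, prefix, torch.device("cpu"))
```

Passing `map_location` lets a checkpoint written on a GPU machine be restored on a CPU-only one.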

__patch_instance_norm_state_dict(module: torch.nn.modules.module.Module, keys: List[str], i: int = 0) → None

Fix the incompatibility of InstanceNorm checkpoints saved with PyTorch versions prior to 0.4.

Parameters
  • state_dict (dict) – a dict containing parameters from the saved model files

  • module (nn.Module) – the network loaded from a file

  • keys (List[str]) – the keys inside the save file

  • i (int) – current index in network structure
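The name and description of this helper match the one in the public pytorch-CycleGAN-and-pix2pix reference code. The sketch below follows that approach as a free function and assumes `state_dict` is passed explicitly, as the parameter list (not the signature) suggests. Old checkpoints store `running_mean`, `running_var`, and `num_batches_tracked` for InstanceNorm layers that no longer track statistics; the function walks the dotted key path recursively and drops those entries so strict loading succeeds.

```python
from typing import Dict, List

import torch
from torch import nn


def patch_instance_norm_state_dict(
    state_dict: Dict[str, torch.Tensor], module: nn.Module, keys: List[str], i: int = 0
) -> None:
    """Drop InstanceNorm entries that pre-0.4 checkpoints contain but current modules lack."""
    key = keys[i]
    if i + 1 == len(keys):  # reached the parameter/buffer name
        is_instance_norm = module.__class__.__name__.startswith("InstanceNorm")
        if is_instance_norm and key in ("running_mean", "running_var"):
            if getattr(module, key) is None:  # module does not track running stats
                state_dict.pop(".".join(keys))
        if is_instance_norm and key == "num_batches_tracked":
            state_dict.pop(".".join(keys))
    else:
        # Descend one level, e.g. keys ["1", "running_mean"] -> submodule "1".
        patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)


net = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3), nn.InstanceNorm2d(3))

# Simulate an old checkpoint that still stores InstanceNorm statistics.
state_dict = net.state_dict()
state_dict["1.running_mean"] = torch.zeros(3)
state_dict["1.running_var"] = torch.ones(3)
state_dict["1.num_batches_tracked"] = torch.tensor(0)

for full_key in list(state_dict.keys()):  # copy the keys: the dict is modified in the loop
    patch_instance_norm_state_dict(state_dict, net, full_key.split("."))

net.load_state_dict(state_dict)  # strict loading now succeeds
```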

print(verbose: bool) → None

Print the total number of parameters in the networks and, if verbose, the network architecture.

Parameters

verbose (bool) – if True, also print the network architecture
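Counting parameters amounts to summing `numel()` over `net.parameters()`, and printing a module shows its layer structure. The following sketch assumes networks are passed as a dict (a hypothetical interface) and returns the counts so they can be checked; the exact message format of the real method is not known.

```python
from typing import Dict, Optional

from torch import nn


def print_networks(nets: Dict[str, Optional[nn.Module]], verbose: bool) -> Dict[str, int]:
    """Print parameter counts (and, if verbose, architectures); return the counts."""
    counts = {}
    for name, net in nets.items():
        if net is None:  # skip absent networks, e.g. discriminators at test time
            continue
        num_params = sum(p.numel() for p in net.parameters())
        counts[name] = num_params
        if verbose:
            print(net)  # nn.Module's repr lists the layer structure
        print(f"[Network {name}] Total number of parameters: {num_params / 1e6:.3f} M")
    return counts


# A 3x3 conv with 3 input and 3 output channels: 3*3*3*3 weights + 3 biases = 84.
counts = print_networks({"g_a_to_b": nn.Conv2d(3, 3, kernel_size=3), "d_a": None}, verbose=False)
```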

class BaseModel(netg_a_to_b, netg_b_to_a, netd_a, netd_b, is_train, lambda_cycle, lambda_idt_a, lambda_idt_b, is_l1, optimizer_type, lr_policy, beta1: float = 0.5, lr: float = 0.0002, cycle_noise_stddev: float = 0)

Bases: abc.ABC, simulation.utils.basics.init_options.InitOptions

forward(real_a, real_b) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Run the forward pass; called by both <optimize_parameters> and <test>.

test(batch_a, batch_b) → simulation.utils.machine_learning.cycle_gan.models.cycle_gan_stats.CycleGANStats

Forward function used at test time.

This function wraps <forward> in no_grad() so that intermediate steps are not saved for backpropagation. It also calls <compute_visuals> to produce additional visualization results.
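The forward/test split above can be sketched as follows. The four returned tensors are assumed here to be the two translations and the two cycle reconstructions, in a hypothetical order; the real method returns a CycleGANStats object rather than a tuple, and the generators below are placeholders. Running the forward pass inside `torch.no_grad()` means no autograd graph is recorded, so the outputs do not require gradients.

```python
from typing import Tuple

import torch
from torch import nn

g_a_to_b = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # placeholder generator a -> b
g_b_to_a = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # placeholder generator b -> a


def forward(
    real_a: torch.Tensor, real_b: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Both translation directions plus reconstructions (output order is an assumption)."""
    fake_b = g_a_to_b(real_a)
    rec_a = g_b_to_a(fake_b)  # a -> b -> a
    fake_a = g_b_to_a(real_b)
    rec_b = g_a_to_b(fake_a)  # b -> a -> b
    return fake_a, fake_b, rec_a, rec_b


def run_at_test_time(real_a: torch.Tensor, real_b: torch.Tensor):
    with torch.no_grad():  # do not store intermediate results for backpropagation
        return forward(real_a, real_b)


outputs = run_at_test_time(torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8))
```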

create_schedulers(lr_policy: str = 'linear', lr_decay_iters: int = 50, lr_step_factor: float = 0.1, n_epochs: int = 100)

Create schedulers.

Parameters
  • lr_policy – learning rate policy. [linear | step | plateau | cosine]

  • lr_decay_iters – multiply the learning rate by a gamma factor every lr_decay_iters iterations

  • lr_step_factor – multiply the learning rate by this factor every epoch

  • n_epochs – number of epochs with the initial learning rate
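The four policies map naturally onto schedulers in `torch.optim.lr_scheduler`. This is a sketch under stated assumptions, not this module's implementation: the linear rule (keep the initial rate for `n_epochs`, then decay linearly) follows the pix2pix-style approach and needs a hypothetical extra parameter `n_epochs_decay`; using `lr_step_factor` as the step/plateau factor is likewise an assumption.

```python
import torch
from torch import nn
from torch.optim import Optimizer, lr_scheduler


def make_scheduler(
    optimizer: Optimizer,
    lr_policy: str = "linear",
    lr_decay_iters: int = 50,
    lr_step_factor: float = 0.1,
    n_epochs: int = 100,
    n_epochs_decay: int = 100,  # hypothetical: length of the linear decay phase
):
    if lr_policy == "linear":
        # Assumed rule: factor 1.0 for the first n_epochs, then a linear decrease.
        def rule(epoch: int) -> float:
            return 1.0 - max(0, epoch + 1 - n_epochs) / float(n_epochs_decay + 1)

        return lr_scheduler.LambdaLR(optimizer, lr_lambda=rule)
    if lr_policy == "step":
        return lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=lr_step_factor)
    if lr_policy == "plateau":
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=lr_step_factor)
    if lr_policy == "cosine":
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=0)
    raise NotImplementedError(f"learning rate policy [{lr_policy}] is not implemented")


def lr_history(lr_policy: str, epochs: int, **kwargs) -> list:
    """Record the learning rate at the start of each epoch (not for 'plateau')."""
    param = nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=1.0)
    scheduler = make_scheduler(optimizer, lr_policy, **kwargs)
    history = []
    for _ in range(epochs):
        history.append(optimizer.param_groups[0]["lr"])
        optimizer.step()
        scheduler.step()
    return history


step_lrs = lr_history("step", 4, lr_decay_iters=2, lr_step_factor=0.1)
linear_lrs = lr_history("linear", 4, n_epochs=2, n_epochs_decay=2)
```

With these settings the step policy yields rates of 1.0, 1.0, 0.1, 0.1, and the linear policy holds 1.0 for two epochs before dropping to 2/3 and then 1/3.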

eval() → None

Put the models in eval mode at test time.

update_learning_rate() → None

Update learning rates for all the networks.
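Both methods reduce to simple loops. In this sketch the helper names, the list-based interface, and the `metric` argument are hypothetical: `nn.Module.eval()` switches layers such as dropout and batch norm to inference behavior, and every scheduler is stepped once, with `ReduceLROnPlateau` needing a monitored value that the other schedulers do not take.

```python
import torch
from torch import nn
from torch.optim import lr_scheduler


def set_eval_mode(nets) -> None:
    """Switch every available network to eval mode (affects dropout, batch norm, ...)."""
    for net in nets:
        if net is not None:
            net.eval()


def update_learning_rate(schedulers, metric=None) -> None:
    """Advance every scheduler by one step; the plateau policy needs a monitored metric."""
    for scheduler in schedulers:
        if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
            scheduler.step(metric)
        else:
            scheduler.step()


generator = nn.Sequential(nn.Conv2d(3, 3, kernel_size=1), nn.Dropout(0.5))
set_eval_mode([generator, None])

opt_g = torch.optim.SGD(generator.parameters(), lr=1.0)
opt_d = torch.optim.SGD([nn.Parameter(torch.zeros(1))], lr=1.0)
schedulers = [
    lr_scheduler.StepLR(opt_g, step_size=1, gamma=0.5),
    lr_scheduler.ReduceLROnPlateau(opt_d, mode="min", factor=0.5, patience=0),
]
for _ in range(2):
    opt_g.step()
    opt_d.step()
    update_learning_rate(schedulers, metric=1.0)  # constant loss: no improvement
```

After two updates the step scheduler has halved the rate twice (0.25), while the plateau scheduler halved it once (0.5) because the second, unchanged metric counted as a lack of improvement.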

abstract do_iteration(batch_a: Tuple[torch.Tensor, str], batch_b: Tuple[torch.Tensor, str])

pre_training()