# Source code for simulation.utils.machine_learning.cycle_gan.configs.train_options

from typing import List

from .base_options import BaseOptions

class TrainOptions(BaseOptions):
    """Training options layered on top of ``BaseOptions``.

    Each option is an annotated class attribute with its default value,
    followed by an attribute docstring describing it. Both
    ``WassersteinCycleGANTrainOptions`` and ``CycleGANTrainOptions`` in this
    module inherit these options.
    """

    # NOTE(review): the two list defaults below are class-level objects.
    # Mutating them in place would be visible to every user of the class
    # unless BaseOptions copies them -- confirm before appending to them.
    dataset_a: List[str] = [
        "./../../../../data/real_images/beg_2019",
    ]
    """Path to images of domain A (real images). Can be a list of folders."""
    dataset_b: List[str] = ["./../../../../data/simulated_images/random_roads"]
    """Path to images of domain B (simulated images). Can be a list of folders."""

    display_id: int = 1
    """Window id of the web display."""
    display_port: int = 8097
    """Visdom port of the web display."""

    is_train: bool = True
    """Enable or disable training mode."""
    num_threads: int = 8
    """Number of threads for loading data."""
    save_freq: int = 100
    """Frequency of saving the current models."""
    print_freq: int = 5
    """Frequency of showing training results on console."""

    beta1: float = 0.5
    """Momentum term of Adam."""
    batch_size: int = 3
    """Input batch size."""
    lr: float = 0.0005
    """Initial learning rate for Adam."""
    lr_decay_iters: int = 1
    """Multiply by a gamma every lr_decay_iters iterations."""
    lr_policy: str = "step"
    """Learning rate policy. [linear | step | plateau | cosine]"""
    lr_step_factor: float = 0.1
    """Multiplication factor at every step in the step scheduler."""
    n_epochs: int = 0
    """Number of epochs with the initial learning rate."""
    n_epochs_decay: int = 10
    """Number of epochs to linearly decay learning rate to zero."""

    no_flip: bool = False
    """Controls the random flip of training images (50% of them, vertically).

    NOTE(review): the name suggests that ``True`` *disables* the flip, so the
    default ``False`` would leave flipping enabled -- confirm against the code
    that reads this option.
    """
    continue_train: bool = False
    """Load checkpoints or start from scratch."""
class WassersteinCycleGANTrainOptions(TrainOptions):
    """Training options for the Wasserstein variant of the CycleGAN.

    Extends ``TrainOptions`` with the critic iteration counts and the
    weight-clipping bounds, and sets ``is_wgan`` to ``True``.
    """

    wgan_initial_n_critic: int = 1
    """Number of iterations of the critic before starting training loop."""
    wgan_n_critic: int = 5
    """Number of iterations of the critic per generator iteration."""

    wgan_clip_upper: float = 0.001
    """Upper bound for weight clipping."""
    wgan_clip_lower: float = -0.001
    """Lower bound for weight clipping."""

    is_wgan: bool = True
    """Decide whether to use Wasserstein CycleGAN or standard CycleGAN."""
class CycleGANTrainOptions(TrainOptions):
    """Training options for the CycleGAN.

    Declares no options of its own; every option and default is inherited
    unchanged from ``TrainOptions``.
    """