from typing import List
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
    dataset_a: List[str] = [
        "./../../../../data/real_images/beg_2019",
    ]
    """Path to images of domain A (real images). Can be a list of folders."""
    dataset_b: List[str] = ["./../../../../data/simulated_images/random_roads"]
    """Path to images of domain B (simulated images). Can be a list of folders."""
    display_id: int = 1
    """Window ID of the web display."""
    display_port: int = 8097
    """Visdom port of the web display."""
    is_train: bool = True
    """Enable or disable training mode."""
    num_threads: int = 8
    """Number of threads for loading the data."""
    save_freq: int = 100
    """Frequency of saving the current models."""
    print_freq: int = 5
    """Frequency of showing training results on the console."""
    beta1: float = 0.5
    """Momentum term (beta1) of the Adam optimizer."""
    batch_size: int = 3
    """Input batch size."""
    lr: float = 0.0005
    """Initial learning rate for Adam."""
    lr_decay_iters: int = 1
    """Multiply the learning rate by a gamma every lr_decay_iters iterations."""
    lr_policy: str = "step"
    """Learning rate policy. [linear | step | plateau | cosine]"""
    lr_step_factor: float = 0.1
    """Multiplication factor applied at every step of the step scheduler."""
    n_epochs: int = 0
    """Number of epochs with the initial learning rate."""
    n_epochs_decay: int = 10
    """Number of epochs over which the learning rate linearly decays to zero."""
    no_flip: bool = False
    """If False, flip 50% of all training images vertically for data augmentation."""
    continue_train: bool = False
    """If True, load saved checkpoints to continue training; otherwise start from scratch."""

class WassersteinCycleGANTrainOptions(TrainOptions):
    wgan_initial_n_critic: int = 1
    """Number of critic iterations run before the main training loop starts."""
    wgan_clip_upper: float = 0.001
    """Upper bound for weight clipping."""
    wgan_clip_lower: float = -0.001
    """Lower bound for weight clipping."""
    wgan_n_critic: int = 5
    """Number of critic iterations per generator iteration."""
    is_wgan: bool = True
    """Whether to use the Wasserstein CycleGAN or the standard CycleGAN."""

class CycleGANTrainOptions(TrainOptions):
    pass
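
# Hedged usage note (not part of the original module): a training script can
# simply instantiate the option set for the desired GAN variant, e.g.
#
#     opt = WassersteinCycleGANTrainOptions()  # or CycleGANTrainOptions()
#     print(opt.batch_size, opt.lr, opt.wgan_n_critic)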