Source code for simulation.utils.machine_learning.cycle_gan.models.base_model

import abc
import itertools
import os
import pickle
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor, nn
from torch.nn import L1Loss, MSELoss
from torch.optim import RMSprop

from simulation.utils.basics.init_options import InitOptions
from simulation.utils.machine_learning.models import helper

from .cycle_gan_stats import CycleGANStats


@dataclass
class CycleGANNetworks:
    """Container class for all networks used within the CycleGAN.

    The CycleGAN generally requires images from two domains a and b. It aims to
    translate images from one domain to the other.
    """

    g_a_to_b: nn.Module
    """Generator that transforms images from domain a to domain b."""
    g_b_to_a: nn.Module
    """Generator that transforms images from domain b to domain a."""
    d_a: Optional[nn.Module] = None
    """Discriminator that decides for images if they are real or fake in domain a."""
    d_b: Optional[nn.Module] = None
    """Discriminator that decides for images if they are real or fake in domain b."""

    def save(self, prefix_path: str) -> None:
        """Save all the networks to disk.

        Args:
            prefix_path (str): the path which gets extended by the model name
        """
        for name, net in self.__dict__.items():
            if net is None:
                continue
            # Deep copy through pickle so the saved module is detached from the
            # live training instance.
            net = pickle.loads(pickle.dumps(net))
            save_path = prefix_path + f"{name}.pth"
            torch.save(net.state_dict(), save_path)

    def load(self, prefix_path: str, device: torch.device):
        """Load all the networks from the disk.

        Args:
            prefix_path (str): the path which is extended by the model name
            device (torch.device): The device on which the networks are loaded
        """
        for name, net in self.__dict__.items():
            if net is None:
                continue
            load_path = prefix_path + f"{name}.pth"
            if not os.path.isfile(load_path):
                raise FileNotFoundError(f"No model weights file found at {load_path}")
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            # On PyTorch newer than 0.4 (e.g. built from GitHub source),
            # the str() around device can be removed.
            state_dict = torch.load(load_path, map_location=str(device))
            print(f"Loaded: {load_path}")
            if hasattr(state_dict, "_metadata"):
                del state_dict._metadata

            # Patch InstanceNorm checkpoints created prior to PyTorch 0.4.
            # Copy the keys first because the dict is mutated inside the loop.
            for key in list(state_dict.keys()):
                CycleGANNetworks.__patch_instance_norm_state_dict(
                    state_dict, net, key.split(".")
                )
            net.load_state_dict(state_dict)

    @staticmethod
    def __patch_instance_norm_state_dict(
        state_dict: dict, module: nn.Module, keys: List[str], i: int = 0
    ) -> None:
        """Fix InstanceNorm checkpoint incompatibility (prior to PyTorch 0.4).

        Args:
            state_dict (dict): a dict containing parameters from the saved model files
            module (nn.Module): the network loaded from a file
            keys (List[str]): the keys inside the save file
            i (int): current index in network structure
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith("InstanceNorm") and key in (
                "running_mean",
                "running_var",
            ):
                if getattr(module, key) is None:
                    state_dict.pop(".".join(keys))
            if (
                module.__class__.__name__.startswith("InstanceNorm")
                and key == "num_batches_tracked"
            ):
                state_dict.pop(".".join(keys))
        else:
            CycleGANNetworks.__patch_instance_norm_state_dict(
                state_dict, getattr(module, key), keys, i + 1
            )

    def print(self, verbose: bool) -> None:
        """Print the total number of parameters in each network and (if verbose) the
        network architecture.

        Args:
            verbose (bool): print the network architecture
        """
        print("---------- Networks initialized -------------")
        for name, net in self.__dict__.items():
            if net is None:
                continue
            num_params = sum(param.numel() for param in net.parameters())
            if verbose:
                print(net)
            print(f"[Network {name}] Total number of parameters : {num_params / 1e6:.3f} M")
        print("-----------------------------------------------")

    def __iter__(self):
        return (net for net in self.__dict__.values() if net is not None)


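# Usage sketch (illustrative, not part of the original module): a save/load
# round trip through ``CycleGANNetworks``. The ``make_net`` factory is a
# hypothetical placeholder for whatever generator/discriminator constructors
# the surrounding project provides; the checkpoint directory is assumed to exist.
def _example_networks_roundtrip(make_net) -> None:
    networks = CycleGANNetworks(
        g_a_to_b=make_net(), g_b_to_a=make_net(), d_a=make_net(), d_b=make_net()
    )
    # Each network is written to <prefix><field name>.pth,
    # e.g. "checkpoints/run1_g_a_to_b.pth" for the prefix below.
    networks.save("checkpoints/run1_")
    networks.load("checkpoints/run1_", device=torch.device("cpu"))

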
class BaseModel(ABC, InitOptions):
    def __init__(
        self,
        netg_a_to_b: nn.Module,
        netg_b_to_a: nn.Module,
        netd_a: nn.Module,
        netd_b: nn.Module,
        is_train: bool,
        lambda_cycle: float,
        lambda_idt_a: float,
        lambda_idt_b: float,
        is_l1: bool,
        optimizer_type: str,
        lr_policy: str,
        beta1: float = 0.5,
        lr: float = 0.0002,
        cycle_noise_stddev: float = 0,
    ):
        self.is_train = is_train
        self.lambda_cycle = lambda_cycle
        self.lambda_idt_a = lambda_idt_a
        self.lambda_idt_b = lambda_idt_b
        self.is_l1 = is_l1
        self.metric = 0  # used for learning rate policy 'plateau'
        self.lr_policy = lr_policy
        self.cycle_noise_stddev = cycle_noise_stddev if is_train else 0

        self.networks = CycleGANNetworks(netg_a_to_b, netg_b_to_a, netd_a, netd_b)

        if self.is_train:
            # Define the loss functions.
            self.criterionCycle = L1Loss() if self.is_l1 else MSELoss()
            self.criterionIdt = L1Loss() if self.is_l1 else MSELoss()

            if optimizer_type == "rms_prop":
                self.optimizer_g = RMSprop(
                    itertools.chain(
                        self.networks.g_a_to_b.parameters(),
                        self.networks.g_b_to_a.parameters(),
                    ),
                    lr=lr,
                )
                self.optimizer_d = RMSprop(
                    itertools.chain(
                        self.networks.d_a.parameters(), self.networks.d_b.parameters()
                    ),
                    lr=lr,
                )
            else:
                self.optimizer_g = torch.optim.Adam(
                    itertools.chain(
                        self.networks.g_a_to_b.parameters(),
                        self.networks.g_b_to_a.parameters(),
                    ),
                    lr=lr,
                    betas=(beta1, 0.999),
                )
                self.optimizer_d = torch.optim.Adam(
                    itertools.chain(
                        self.networks.d_a.parameters(), self.networks.d_b.parameters()
                    ),
                    lr=lr,
                    betas=(beta1, 0.999),
                )

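    # Construction sketch (assumption, not from the original module): concrete
    # subclasses are instantiated with the four networks plus the training
    # hyperparameters. ``SomeCycleGANSubclass`` and the lambda values below are
    # hypothetical::
    #
    #     model = SomeCycleGANSubclass(
    #         netg_a_to_b, netg_b_to_a, netd_a, netd_b,
    #         is_train=True, lambda_cycle=10.0, lambda_idt_a=0.5, lambda_idt_b=0.5,
    #         is_l1=True, optimizer_type="adam", lr_policy="linear",
    #     )
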
    def forward(self, real_a, real_b) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        fake_b = self.networks.g_a_to_b(real_a)  # G_A(A)
        fake_a = self.networks.g_b_to_a(real_b)  # G_B(B)

        # Calculate the cycle. Add gaussian noise if self.cycle_noise_stddev != 0.
        # See: https://discuss.pytorch.org/t/writing-a-simple-gaussian-noise-layer-in-pytorch/4694  # noqa: E501
        # There are two individual noise terms because fake_a and fake_b may have
        # different dimensions (e.g. at the end of the dataset, where one of them
        # is not a full batch).
        if self.cycle_noise_stddev == 0:
            noise_a = 0
            noise_b = 0
        else:
            # zeros_like keeps the noise on the same device and dtype as the
            # generated images.
            noise_a = (
                torch.zeros_like(fake_a)
                .normal_(0, self.cycle_noise_stddev)
                .requires_grad_()
            )
            noise_b = (
                torch.zeros_like(fake_b)
                .normal_(0, self.cycle_noise_stddev)
                .requires_grad_()
            )

        rec_a = self.networks.g_b_to_a(fake_b + noise_b)
        rec_b = self.networks.g_a_to_b(fake_a + noise_a)
        return fake_a, fake_b, rec_a, rec_b

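    # Sketch (assumption, not from the original module): subclasses typically
    # turn the forward outputs into the cycle-consistency loss using the
    # criterion defined in ``__init__``, scaled by ``lambda_cycle``::
    #
    #     fake_a, fake_b, rec_a, rec_b = self.forward(real_a, real_b)
    #     loss_cycle = (
    #         self.criterionCycle(rec_a, real_a) + self.criterionCycle(rec_b, real_b)
    #     ) * self.lambda_cycle
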
    def test(self, batch_a, batch_b) -> CycleGANStats:
        """Forward function used at test time.

        This function wraps <forward> in no_grad() so that intermediate steps are
        not stored for backpropagation.
        """
        with torch.no_grad():
            fake_a, fake_b, rec_a, rec_b = self.forward(batch_a, batch_b)
            return CycleGANStats(batch_a, batch_b, fake_a, fake_b, rec_a, rec_b)

    def create_schedulers(
        self,
        lr_policy: str = "linear",
        lr_decay_iters: int = 50,
        lr_step_factor: float = 0.1,
        n_epochs: int = 100,
    ):
        """Create learning rate schedulers for all optimizers.

        Args:
            lr_policy: learning rate policy. [linear | step | plateau | cosine]
            lr_decay_iters: multiply by a gamma every lr_decay_iters iterations
            lr_step_factor: multiply the learning rate by this factor every epoch
            n_epochs: number of epochs with the initial learning rate
        """
        self.schedulers = [
            helper.get_scheduler(
                optimizer, lr_policy, lr_decay_iters, n_epochs, lr_step_factor
            )
            for optimizer in [self.optimizer_d, self.optimizer_g]
        ]

    def eval(self) -> None:
        """Set all networks to eval mode; used during test time."""
        for net in self.networks:
            net.eval()

    def update_learning_rate(self) -> None:
        """Update the learning rates of all networks."""
        old_lr = self.optimizer_g.param_groups[0]["lr"]
        for scheduler in self.schedulers:
            if self.lr_policy == "plateau":
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizer_g.param_groups[0]["lr"]
        print(f"learning rate {old_lr:.7f} -> {lr:.7f}")

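    # Usage sketch (assumption): a typical training loop creates the schedulers
    # once and then advances them after every epoch::
    #
    #     model.create_schedulers(lr_policy="linear", n_epochs=100)
    #     for epoch in range(num_epochs):
    #         ...  # iterate over batches, calling model.do_iteration(...)
    #         model.update_learning_rate()
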
    @abc.abstractmethod
    def do_iteration(
        self, batch_a: Tuple[torch.Tensor, str], batch_b: Tuple[torch.Tensor, str]
    ):
        raise NotImplementedError("Abstract method!")

    def pre_training(self):
        pass
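

# A minimal subclass sketch (assumption, not part of the original module):
# one way to satisfy the abstract ``do_iteration``. A complete implementation
# would also compute GAN and identity losses and update the discriminators via
# ``self.optimizer_d``; both are elided here. The stats constructor mirrors the
# call in ``test`` above.
class _ExampleCycleGANModel(BaseModel):
    def do_iteration(
        self, batch_a: Tuple[torch.Tensor, str], batch_b: Tuple[torch.Tensor, str]
    ) -> CycleGANStats:
        real_a, real_b = batch_a[0], batch_b[0]
        fake_a, fake_b, rec_a, rec_b = self.forward(real_a, real_b)

        # Generator update driven by the cycle-consistency loss alone.
        self.optimizer_g.zero_grad()
        loss_g = (
            self.criterionCycle(rec_a, real_a) + self.criterionCycle(rec_b, real_b)
        ) * self.lambda_cycle
        loss_g.backward()
        self.optimizer_g.step()

        return CycleGANStats(real_a, real_b, fake_a, fake_b, rec_a, rec_b)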