import argparse
import os
import pickle
from typing import Tuple
import torch
from torch import nn
import simulation.utils.machine_learning.data as ml_data
from simulation.utils.machine_learning.cycle_gan.configs.test_options import (
CycleGANTestOptions,
WassersteinCycleGANTestOptions,
)
from simulation.utils.machine_learning.cycle_gan.models.cycle_gan_model import CycleGANModel
from simulation.utils.machine_learning.cycle_gan.models.generator import create_generator
from simulation.utils.machine_learning.cycle_gan.models.wcycle_gan import (
WassersteinCycleGANModel,
)
from simulation.utils.machine_learning.data import DataLoader
from simulation.utils.machine_learning.data.image_operations import save_images
from simulation.utils.machine_learning.models.helper import get_norm_layer, init_net
from simulation.utils.machine_learning.models.resnet_generator import ResnetGenerator
[docs]def test_on_dataset(
dataset: DataLoader,
generators: Tuple[nn.Module, nn.Module],
class_names: Tuple[str, str],
destination: str,
aspect_ratio: float = 1,
device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
):
"""Test one dataset and save all images.
Args:
dataset: The dataset to test
generators: Both generators. Second one is used to generate the fake image.
class_names: The class names to save the images correctly.
destination: The destination folder
aspect_ratio: The aspect ratio of the images
device: the device on which the models are located
"""
for i, (real_image, _) in enumerate(dataset):
real_image = real_image.to(device)
fake_image = generators[0](real_image)
idt_image = generators[1](real_image)
cycle_image = generators[1](fake_image)
visuals = {
f"real_{class_names[0]}": real_image,
f"fake_{class_names[1]}": fake_image,
f"idt_{class_names[0]}": idt_image,
f"cycle_{class_names[0]}": cycle_image,
}
print(f"Processing {i}-th image on dataset {class_names[0]}.")
save_images(
visuals=visuals,
destination=destination,
aspect_ratio=aspect_ratio,
post_fix=str(i),
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--gan_type",
type=str,
default="default",
help="Decide whether to use Wasserstein gan or default gan [default, wgan]",
)
use_wasserstein = parser.parse_args().gan_type == "wgan"
opt = WassersteinCycleGANTestOptions if use_wasserstein else CycleGANTestOptions
tf_properties = {
"load_size": opt.load_size,
"crop_size": opt.crop_size,
"preprocess": opt.preprocess,
"mask": opt.mask,
}
dataset_a, dataset_b = ml_data.load_unpaired_unlabeled_datasets(
opt.dataset_a,
opt.dataset_b,
batch_size=1,
sequential=True,
num_threads=0,
grayscale_a=(opt.input_nc == 1),
grayscale_b=(opt.output_nc == 1),
max_dataset_size=opt.max_dataset_size,
transform_properties=tf_properties,
) # create datasets for each domain (A and B)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if opt.is_wgan:
netg_a_to_b = ResnetGenerator(
opt.input_nc,
opt.output_nc,
opt.ngf,
get_norm_layer(opt.norm),
dilations=opt.dilations,
conv_layers_in_block=opt.conv_layers_in_block,
)
else:
netg_a_to_b = create_generator(
opt.input_nc,
opt.output_nc,
opt.ngf,
opt.netg,
opt.norm,
not opt.no_dropout,
opt.activation,
opt.conv_layers_in_block,
opt.dilations,
)
netg_b_to_a = pickle.loads(pickle.dumps(netg_a_to_b))
netg_a_to_b = init_net(netg_a_to_b, opt.init_type, opt.init_gain, device)
netg_b_to_a = init_net(netg_b_to_a, opt.init_type, opt.init_gain, device)
ModelClass = CycleGANModel if not opt.is_wgan else WassersteinCycleGANModel
model = ModelClass.from_dict(
netg_a_to_b=netg_a_to_b, netg_b_to_a=netg_b_to_a, **opt.to_dict()
)
model.networks.load(
os.path.join(opt.checkpoints_dir, opt.name, f"{opt.epoch}_net_"),
device=device,
)
model.networks.print(opt.verbose)
model.eval()
test_on_dataset(
dataset_a,
(model.networks.g_a_to_b, model.networks.g_b_to_a),
("a", "b"),
destination=os.path.join(opt.results_dir, opt.name),
aspect_ratio=opt.aspect_ratio,
device=device,
)
test_on_dataset(
dataset_b,
(model.networks.g_b_to_a, model.networks.g_a_to_b),
("b", "a"),
destination=os.path.join(opt.results_dir, opt.name),
aspect_ratio=opt.aspect_ratio,
device=device,
)