Source code for simulation.utils.machine_learning.cycle_gan.models.test.test_base_model

import os
import pickle
import shutil
import tempfile
from itertools import chain

import torch
from hypothesis import given
from hypothesis.strategies import integers

from ..base_model import CycleGANNetworks
from ..generator import create_generator


@given(integers(1, 3), integers(1, 3))
def test_save(input_nc, output_nc):
    """Check that ``CycleGANNetworks.save`` writes a file for each generator.

    Two generators are wrapped in a ``CycleGANNetworks`` instance, saved into
    a temporary directory, and the test asserts that ``g_a_to_b.pth`` and
    ``g_b_to_a.pth`` exist afterwards.

    Args:
        input_nc: First positional argument of ``create_generator``; hypothesis
            draws it from the integers 1 to 3 (presumably the number of input
            channels -- confirm against ``create_generator``).
        output_nc: Second positional argument of ``create_generator``; drawn
            from the integers 1 to 3 (presumably the number of output channels
            -- confirm against ``create_generator``).
    """
    print("Test saving networks")

    net1 = create_generator(input_nc, output_nc, 64, netg="resnet_3blocks")
    # A pickle round-trip produces a second, independent generator object.
    net2 = pickle.loads(pickle.dumps(net1))
    nets = CycleGANNetworks(net1, net2)

    # Work in a private temporary directory: it cannot collide with an
    # existing "temp" folder in the working directory, and the context manager
    # removes it even when save() raises or an assertion below fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The path handed to save() ends with a separator.
        # NOTE(review): save() presumably uses its argument as a path prefix
        # for the checkpoint file names -- confirm against CycleGANNetworks.
        save_prefix = os.path.join(tmp_dir, "")
        nets.save(save_prefix)

        assert os.path.isfile(os.path.join(tmp_dir, "g_a_to_b.pth"))
        assert os.path.isfile(os.path.join(tmp_dir, "g_b_to_a.pth"))
@given(integers(1, 3), integers(1, 3))
def test_load(input_nc, output_nc):
    """Check that ``CycleGANNetworks.load`` restores previously saved weights.

    One pair of generators gets randomized parameters and is saved to a
    temporary directory. A second pair, which still carries the initial
    parameters, must differ from the first pair before ``load`` is called and
    match it afterwards.

    Args:
        input_nc: First positional argument of ``create_generator``; hypothesis
            draws it from the integers 1 to 3 (presumably the number of input
            channels -- confirm against ``create_generator``).
        output_nc: Second positional argument of ``create_generator``; drawn
            from the integers 1 to 3 (presumably the number of output channels
            -- confirm against ``create_generator``).
    """
    print("Test loading networks")

    net1 = create_generator(input_nc, output_nc, 64, netg="resnet_3blocks")
    # Pickle round-trips give three further independent copies of net1 that
    # all start with the same parameters.
    net2 = pickle.loads(pickle.dumps(net1))
    net3 = pickle.loads(pickle.dumps(net1))
    net4 = pickle.loads(pickle.dumps(net1))

    # Overwrite the parameters of the networks that will be saved. Each
    # parameter tensor is filled with a single random scalar, so net1/net2 no
    # longer match the untouched copies net3/net4.
    for param in chain(net1.parameters(), net2.parameters()):
        param.data = torch.rand(1).expand_as(param.data)

    def model_equals(model1, model2):
        """Return True if all paired parameters of both models are equal."""
        # torch.equal also compares shapes, so tensors of different sizes are
        # reported as unequal instead of being broadcast against each other.
        return all(
            torch.equal(p1.data, p2.data)
            for p1, p2 in zip(model1.parameters(), model2.parameters())
        )

    nets1 = CycleGANNetworks(net1, net2)
    nets2 = CycleGANNetworks(net3, net4)

    # Work in a private temporary directory: it cannot collide with an
    # existing "temp" folder in the working directory, and the context manager
    # removes it even when save()/load() raise or an assertion fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The path handed to save()/load() ends with a separator.
        # NOTE(review): both presumably use their argument as a path prefix
        # for the checkpoint file names -- confirm against CycleGANNetworks.
        checkpoint_prefix = os.path.join(tmp_dir, "")
        nets1.save(checkpoint_prefix)

        # Before loading, the second pair still has the initial parameters.
        assert not model_equals(nets1.g_a_to_b, nets2.g_a_to_b)
        assert not model_equals(nets1.g_b_to_a, nets2.g_b_to_a)

        nets2.load(checkpoint_prefix, torch.device("cpu"))

        # After loading, both pairs must hold identical parameters.
        assert model_equals(nets1.g_a_to_b, nets2.g_a_to_b)
        assert model_equals(nets1.g_b_to_a, nets2.g_b_to_a)
def main():
    """Run the save test and then the load test."""
    for run_test in (test_save, test_load):
        run_test()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    main()