Source code for simulation.utils.machine_learning.models.test.test_resnet_block

import torch

from .. import helper, resnet_block


def test_creating_resnet_block(norm_type, **kwargs):
    """Construct a ``ResnetBlock`` and check that calling it yields a tensor.

    Args:
        norm_type: Identifier handed to ``helper.get_norm_layer``; the result
            is passed to the block as its ``norm_layer`` argument.
        **kwargs: Remaining keyword arguments forwarded unchanged to
            ``resnet_block.ResnetBlock``. Must include ``dim``, which is also
            used as the size of the second axis of the random input.

    Raises:
        KeyError: If ``dim`` is missing from ``kwargs``.
        AssertionError: If the block's output is not a ``torch.Tensor``.
    """
    norm_layer = helper.get_norm_layer(norm_type)

    # Random input of shape (3, dim, 30, 30); drawn before the block is built,
    # matching the order in which the pieces are created.
    sample = torch.rand((3, kwargs["dim"], 30, 30))

    model = resnet_block.ResnetBlock(**{**kwargs, "norm_layer": norm_layer})
    output = model(sample)
    assert isinstance(output, torch.Tensor)
def main():
    """Run ``test_creating_resnet_block`` against three fixed configurations."""
    # Each entry holds the full keyword-argument set for one check; the
    # entries are executed in the order listed.
    configurations = (
        {
            "dim": 4,
            "padding_type": "replicate",
            "norm_type": "instance",
            "use_dropout": True,
            "use_bias": True,
        },
        {
            "dim": 2,
            "padding_type": "reflect",
            "norm_type": "none",
            "use_dropout": False,
            "use_bias": False,
            "n_conv_layers": 3,
            "dilations": [2, 4, 2],
        },
        {
            "dim": 4,
            "padding_type": "replicate",
            "norm_type": "instance",
            "use_dropout": False,
            "use_bias": False,
            "n_conv_layers": 6,
            "dilations": None,
        },
    )
    for config in configurations:
        test_creating_resnet_block(**config)
# Execute the checks only when this module is run as a script.
if __name__ == "__main__":
    main()