- simulation.utils.machine_learning.cycle_gan.configs package
- simulation.utils.machine_learning.cycle_gan.models package
- simulation.utils.machine_learning.cycle_gan.models.base_model module
- simulation.utils.machine_learning.cycle_gan.models.cycle_gan_model module
- simulation.utils.machine_learning.cycle_gan.models.cycle_gan_stats module
- simulation.utils.machine_learning.cycle_gan.models.discriminator module
- simulation.utils.machine_learning.cycle_gan.models.generator module
- simulation.utils.machine_learning.cycle_gan.models.n_layer_discriminator module
- simulation.utils.machine_learning.cycle_gan.models.no_patch_discriminator module
- simulation.utils.machine_learning.cycle_gan.models.wcycle_gan module
- Module contents
Implementation of a simple ROS interface to translate simulated to “real” images.
Test one dataset and save all images.
- test_on_dataset(dataset: simulation.utils.machine_learning.data.data_loader.DataLoader, generators: Tuple[torch.nn.modules.module.Module, torch.nn.modules.module.Module], class_names: Tuple[str, str], destination: str, aspect_ratio: float = 1, device: torch.device = device(type='cuda', index=0))
dataset – The dataset to test.
generators – Both generators; the second one is used to generate the fake image.
class_names – The class names used to save the images correctly.
destination – The destination folder.
aspect_ratio – The aspect ratio of the images.
device – The device on which the models are located.
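The core of such a test run can be sketched roughly as follows. This is a minimal sketch, not the actual implementation: a small convolution stands in for the generator, plain tensor batches stand in for the DataLoader, and the helper name run_test and the file layout are illustrative only.

```python
import os
import tempfile

import torch
from torch import nn


def run_test(batches, generator, class_name, destination, device):
    """Translate every batch with the generator and save the fakes.

    Illustrative stand-in for a test loop: `batches` is any iterable of
    image tensors, `generator` maps images of one class to the other.
    """
    os.makedirs(os.path.join(destination, class_name), exist_ok=True)
    generator = generator.to(device).eval()
    saved = []
    with torch.no_grad():  # no gradients needed at test time
        for i, batch in enumerate(batches):
            fake = generator(batch.to(device))
            path = os.path.join(destination, class_name, f"fake_{i:04d}.pt")
            torch.save(fake.cpu(), path)
            saved.append(path)
    return saved


# Usage with dummy data: 3 batches of 2 grayscale 8x8 "images".
batches = [torch.rand(2, 1, 8, 8) for _ in range(3)]
generator = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # toy "generator"
dest = tempfile.mkdtemp()
paths = run_test(batches, generator, "real", dest, torch.device("cpu"))
print(len(paths))  # one saved file per batch
```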
The Cycle GAN can be used to convert simulated images into real looking images.
During training, a class A image is “translated” to class B by one generator and “retranslated” back to class A by another generator. The difference between the original and the reconstructed image is the error of the two networks, which the training process minimizes. The same process also runs from class B to A and back to B. The discriminators evaluate whether an image originates from class A or class B; they are trained for this classification at the same time.
PyTorch is used for the implementation.
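The cycle described above can be expressed as a loss term. The following is a minimal sketch of the cycle-consistency idea, using toy linear generators instead of the real convolutional networks; all variable names are illustrative.

```python
import torch
from torch import nn

# Toy generators on flattened 16-pixel "images"; the real models
# are convolutional (see the generator module above).
g_a_to_b = nn.Linear(16, 16)  # translates class A -> class B
g_b_to_a = nn.Linear(16, 16)  # translates class B -> class A

real_a = torch.rand(4, 16)  # a batch of class A images
real_b = torch.rand(4, 16)  # a batch of class B images

# A -> B -> A and B -> A -> B round trips.
reconstructed_a = g_b_to_a(g_a_to_b(real_a))
reconstructed_b = g_a_to_b(g_b_to_a(real_b))

# The cycle-consistency loss is the difference between the original
# and the reconstructed image (L1 distance in the original CycleGAN paper).
cycle_loss = nn.functional.l1_loss(reconstructed_a, real_a) + \
             nn.functional.l1_loss(reconstructed_b, real_b)
cycle_loss.backward()  # gradients flow into both generators
```

In the full model this term is combined with the adversarial losses from the discriminators; minimizing only the cycle loss would let the generators learn the identity mapping.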
The dataset consists of 2 classes of images.
real images (Class A)
simulated images (Class B)
DVC is used for dataset management. With DVC even large datasets can be versioned quickly and easily. The datasets are located under data/real_images and data/simulated_images. They are downloaded with:
dvc pull
Trained models are also stored with DVC. A new model can be added with dvc commit:
dvc commit
dvc push
A git checkout should always be followed by a dvc checkout; otherwise the data can get mixed up.
This is only a short summary of how DVC works; a better explanation can be found in the DVC Tutorial.
All parameters for training can be found in
A new model can be trained with the training script.
The training script automatically starts a visdom server that shows the current state of the training. The script also saves the model periodically; these intermediate states are stored in the checkpoints folder.
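The periodic saving can be sketched like this. This is a hedged sketch only; the save_checkpoint helper, the save interval, and the checkpoint naming are illustrative, not the actual training script.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(model, checkpoints_dir, epoch):
    """Store an intermediate model state (illustrative helper)."""
    os.makedirs(checkpoints_dir, exist_ok=True)
    path = os.path.join(checkpoints_dir, f"epoch_{epoch:03d}.pth")
    torch.save(model.state_dict(), path)
    return path


model = nn.Linear(4, 4)  # stand-in for a generator network
checkpoints = tempfile.mkdtemp()

# Inside the training loop, save every few epochs.
for epoch in range(6):
    if epoch % 2 == 0:  # illustrative save interval: every 2nd epoch
        save_checkpoint(model, checkpoints, epoch)
```

Saving the state_dict rather than the whole module keeps the checkpoints loadable even if the surrounding code changes.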
For testing, checkpoints must already exist.
All parameters for testing can be found in
There are two ways to test a model.
With DVC stages
For testing purposes there are the following DVC stages:
A DVC stage is executed with the following command:
dvc repro -s NAME_OF_STAGE
test_dr_drift: Tests the model by loading the latest checkpoints from the folder checkpoints/dr_drift and creates a results folder containing the test results
make_video_dr_drift: Takes the results from the results folder and cuts three videos together: one with the images from the simulation, another with the “translated” images, and one where the two videos are stacked on top of each other.
stacked_gif_dr_drift: Converts the stacked video into a short gif.
Without DVC stages
A model can be tested with the testing script.
The test script creates a new folder with all generated images in it.
The resulting images can be cut into a video with the script
With CML in the pipeline, each new model is automatically tested and the results are presented in a report attached to the commit or merge request. This requires the model to be named “dr_drift”.