import random
import torch
class ImagePool:
    """An image buffer that stores previously generated images.

    The buffer lets discriminators be updated with a history of generated
    images rather than only the images produced by the latest generators.
    """

    def __init__(self, pool_size: int):
        """Initialize the ImagePool.

        Args:
            pool_size (int): maximum number of images kept in the buffer.
                If ``pool_size <= 0`` no buffer is used and :meth:`query`
                returns its input unchanged.
        """
        self.pool_size = pool_size
        # Always define the buffer state (even when pooling is disabled) so
        # the object's attributes do not depend on the configuration.
        self.num_images = 0  # number of images currently stored
        self.images = []  # stored images, each with a leading dim of size 1

    def query(self, images: torch.Tensor) -> torch.Tensor:
        """Return a batch mixing the input images with buffered ones.

        While the buffer holds fewer than ``pool_size`` images, every input
        image is stored and also returned. Once the buffer is full, each input
        image is handled independently:

        * with 50% probability, a randomly chosen stored image is returned and
          the input image takes its place in the buffer;
        * otherwise the input image itself is returned and the buffer is left
          unchanged.

        Args:
            images (torch.Tensor): the latest generated images from the
                generator, iterated along dimension 0 (one image per entry).

        Returns:
            torch.Tensor: a newly allocated tensor, detached from the autograd
            graph, with the same number of entries along dimension 0 as
            ``images``. If pooling is disabled (``pool_size <= 0``),
            ``images`` itself is returned unchanged.
        """
        if self.pool_size <= 0:  # pooling disabled: pass the batch through
            return images
        if len(images) == 0:
            # Nothing to return or store; torch.cat would raise on an empty list.
            return images.detach()

        return_images = []
        for image in images:
            # Iterating yields a view into the caller's batch; detach it from
            # autograd and add back the leading batch dimension.
            image = image.detach().unsqueeze(0)
            if self.num_images < self.pool_size:
                # Buffer not yet full: store a private copy (so the pool neither
                # pins the whole input batch's storage nor aliases the caller's
                # tensor) and return the current image.
                self.num_images += 1
                self.images.append(image.clone())
                return_images.append(image)
            elif random.uniform(0, 1) > 0.5:
                # Return a previously stored image and keep a copy of the
                # current image in its slot.
                random_id = random.randint(0, self.pool_size - 1)  # inclusive
                return_images.append(self.images[random_id])
                self.images[random_id] = image.clone()
            else:
                # Return the current image; the buffer is left unchanged.
                return_images.append(image)
        # torch.cat allocates a fresh tensor, so the result never aliases the
        # buffered images or the caller's input.
        return torch.cat(return_images, 0)