Source code for

import os
import subprocess
import sys
import time
from typing import Any, Dict

import matplotlib.pyplot as plt
import numpy as np
import torch
from visdom import Visdom

from .image_operations import tensor2im

# Exception class that callers can catch for visdom connection failures.
# ``ConnectionError`` only exists on Python 3, so Python 2 falls back to the
# generic ``Exception``. The ``else`` branch is required: without it Python 3
# never defines the name at all.
if sys.version_info[0] == 2:
    VisdomExceptionBase = Exception
else:
    VisdomExceptionBase = ConnectionError

class Visualizer:
    """Display/save images and print/save logging information.

    Images, loss curves and a hyperparameter table are shown through the
    Python library 'visdom' (only when ``display_id > 0``). Losses are also
    appended to ``<checkpoints_dir>/<name>/loss_log.txt``.
    """

    def __init__(
        self,
        display_id: int = 1,
        name: str = "kitcar",
        display_port: int = 8097,
        checkpoints_dir: str = "./checkpoints",
    ):
        """Initialize the Visualizer class.

        Step 1: Cache the options
        Step 2: launch and connect to a visdom server (only if ``display_id > 0``)
        Step 3: create a logging file to store training losses

        Args:
            display_id (int): window id of the web display; values <= 0 disable
                all visdom output
            name (str): name of the experiment. It decides where to store
                samples and models
            display_port (int): visdom port of the web display
            checkpoints_dir (str): models are saved here
        """
        self.display_id = display_id
        self.name = name
        self.port = display_port

        if self.display_id > 0:
            # create visdom server instance and connect to it
            self.create_visdom_connections(self.port)
            self.vis = Visdom(port=self.port)

        log_dir = os.path.join(checkpoints_dir, name)
        os.makedirs(log_dir, exist_ok=True)
        self.log_name = os.path.join(log_dir, "loss_log.txt")
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write("================ Training Loss (%s) ================\n" % now)

    @staticmethod
    def create_visdom_connections(port: int) -> None:
        """Start a visdom server in the background listening on ``port``.

        The server is launched unconditionally. If one is already running on
        that port, the new process presumably fails to bind and exits (confirm).

        Args:
            port: port the visdom server should listen on
        """
        # Run the server module with the current interpreter and no shell.
        # With ``shell=True`` and an argv list, POSIX executes only "visdom"
        # and hands "-p"/port to the shell, so the port was silently ignored.
        # Nobody reads the output. An unread PIPE can fill up and block the
        # server, so the output is discarded instead.
        subprocess.Popen(
            [sys.executable, "-m", "visdom.server", "-p", str(port)],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        print(f"Launched Visdom server: http://localhost:{port}")

    def show_hyperparameters(self, hyperparameters: Dict[str, Any]):
        """Create an html table with all parameters from the dict and display it on visdom.

        Does nothing when visdom is disabled (``display_id <= 0``).

        Args:
            hyperparameters: a dict containing all hyperparameters
        """
        from html import escape

        if self.display_id <= 0:
            return

        style = (
            "<style>"
            "table"
            "{ border: 1px solid white; width: 500px; height: 200px;"
            "text-align: center; border-collapse: collapse; }"
            "table td, table th"
            "{ border: 1px solid #FFFFFF; padding: 3px 2px; }"
            "table tbody td"
            "{font-size: 13px;}"
            "table tr:nth-child(even)"
            "{ background: #D0E4F5; }"
            "table thead"
            "{ background: #0B6FA4; border-bottom: 5px solid #FFFFFF; }"
            "table thead th {"
            "font-size: 17px; font-weight: bold; color: white;"
            "text-align: center; border-left: 2px solid white;}"
            "table thead th:first-child"
            "{ border-left: none; }"
            "</style> "
        )
        # Escape keys/values. Otherwise reprs such as "<class 'X'>" are parsed
        # as HTML tags by the browser and disappear from the table.
        rows = "".join(
            f"<tr><td>{escape(str(key))}</td><td>{escape(str(value))}</td></tr>"
            for key, value in hyperparameters.items()
        )
        page = (
            style
            + "<h1>Hyperparameters</h1>"
            + "<table>"
            + "<thead><tr><th>Key</th><th>Value</th></tr></thead>"
            + "<tbody>"
            + rows
            + "</tbody></table>"
        )
        self.vis.text(page)

    def display_current_results(
        self, visuals: Dict[str, torch.Tensor], images_per_row: int = 4
    ):
        """Display current results on visdom.

        Does nothing when visdom is disabled or ``visuals`` is empty.

        Args:
            visuals: dictionary of images to display or save
            images_per_row: Amount of images per row
        """
        if self.display_id <= 0 or not visuals:
            return

        # tensor2im presumably returns an (H, W, C) array. The transpose turns
        # it into the (C, H, W) layout passed to visdom (confirm against tensor2im).
        images = [tensor2im(image).transpose([2, 0, 1]) for image in visuals.values()]
        # Pad the last row with white images so that every row is full.
        white_image = np.ones_like(images[-1]) * 255
        images.extend([white_image] * (-len(images) % images_per_row))

        self.vis.images(
            images,
            nrow=images_per_row,
            win=str(self.display_id + 1),
            padding=2,
            opts=dict(title=self.name + " images"),
        )

    def plot_current_losses(self, epoch: int, counter_ratio: float, losses: dict) -> None:
        """Record the current losses and plot their history on visdom.

        The loss names seen on the first call fix the legend, and later calls
        must provide the same keys. The history is always recorded, but it is
        only plotted when visdom is enabled (``display_id > 0``).

        Args:
            epoch: current epoch
            counter_ratio: progress (percentage) in the current epoch, between 0 to 1
            losses: training losses stored in the format of (name, float) pairs
        """
        if not hasattr(self, "plot_data"):
            self.plot_data = {"X": [], "Y": [], "legend": list(losses.keys())}
        legend = self.plot_data["legend"]
        self.plot_data["X"].append(epoch + counter_ratio)
        self.plot_data["Y"].append([losses[k] for k in legend])

        if self.display_id <= 0:
            return

        self.vis.line(
            # One identical x column per loss curve.
            X=np.stack([np.array(self.plot_data["X"])] * len(legend), 1),
            Y=np.array(self.plot_data["Y"]),
            opts={
                "title": self.name + " loss over time",
                "legend": legend,
                "xlabel": "epoch",
                "ylabel": "loss",
            },
            win=str(self.display_id),
        )

    def save_losses_as_image(self, path: str):
        """Save the tracked losses as png file.

        Requires that ``plot_current_losses`` has been called at least once.

        Args:
            path: The path where the loss image should be stored
        """
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)

        # plot_data["Y"] holds one row per time step. zip(*...) transposes it
        # into one series per loss, for any number of losses.
        series_per_loss = zip(*self.plot_data["Y"])
        for y_data, label in zip(series_per_loss, self.plot_data["legend"]):
            ax.plot(self.plot_data["X"], y_data, label=label)
        ax.legend()

        fig.savefig(path)
        # Release the figure. Otherwise each call keeps another figure alive.
        plt.close(fig)

    def print_current_losses(
        self, epoch: int, iters: int, losses: dict, t_comp: float, estimated_time: float
    ) -> None:
        """Print current losses on console; also append them to the log file.

        Args:
            epoch (int): current epoch
            iters (int): current training iteration during this epoch (reset to 0 at
                the end of every epoch)
            losses (dict): training losses stored in the format of (name, float) pairs
            t_comp (float): computational time per data point (normalized by batch_size)
            estimated_time (float): the estimated time in seconds until training finishes
        """
        # Round to whole seconds first. Otherwise e.g. 59.7 s would be
        # formatted as ":60".
        hours, remainder = divmod(round(estimated_time, 0), 3600)
        minutes, seconds = divmod(remainder, 60)
        message = (
            f"(epoch: {epoch}, iters: {iters}, time: {t_comp:.3f}, "
            f"ETA: {hours:02.0f}:{minutes:02.0f}:{seconds:02.0f} hh:mm:ss) "
        )
        message += "".join(f"{k}: {v:.3f} " for k, v in losses.items())

        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write(f"{message}\n")  # save the message