Source code for drytorch.trackers.visdom

"""Module containing a tracker for visualizations on visdom."""

import functools

from collections.abc import Iterable
from typing import Any, TypedDict

import numpy as np
import numpy.typing as npt
import visdom

from typing_extensions import override

from drytorch.core import exceptions, log_events
from drytorch.trackers import base_classes


__all__ = [
    'VisdomOpts',
    'VisdomPlotter',
]


[docs] class VisdomOpts(TypedDict, total=False): """Annotations for optional settings in visdom. See: https://github.com/fossasia/visdom/blob/master/py/visdom/__init__.py. """ title: str xlabel: str ylabel: str zlabel: str # Layout & Size width: int height: int marginleft: int marginright: int margintop: int marginbottom: int # Line style fillarea: bool markers: bool markersymbol: str markersize: float markercolor: npt.NDArray[np.str_] markerborderwidth: float mode: str dash: npt.NDArray[np.str_] # e.g., np.array(['solid', 'dash']) # Axis limits xmin: float xmax: float ymin: float ymax: float # Colors & Style linecolor: npt.NDArray[Any] # shape: (n_lines, 3) RGB colormap: str # matplotlib-style colormap name (for some plots) # Legend legend: list[str] | tuple[str, ...] showlegend: bool # Fonts titlefont: dict[str, Any] # e.g., {"size": 16, "color": "red"} tickfont: dict[str, Any] # Misc opacity: float name: str textlabels: list[str] # used in bar, scatter, etc.
[docs] class VisdomPlotter(base_classes.BasePlotter[str]): """Tracker that uses visdom as backend. Attributes: server: the address for the visdom server. port: the port for the server. opts: plot options. """ server: str port: int opts: VisdomOpts _viz: visdom.Visdom | None def __init__( self, server: str = 'http://localhost', port: int = 8097, opts: VisdomOpts | None = None, model_names: Iterable[str] = (), source_names: Iterable[str] = (), metric_names: Iterable[str] = (), metric_loader: base_classes.MetricLoader | None = None, start: int = 1, ) -> None: """Initialize. Args: server: the address of the server. port: the port for the connection: opts: plot options. model_names: the names of the models to plot. Defaults to all. metric_names: the names of the metrics to plot. Defaults to all. source_names: the names of the sources to plot. Defaults to all. metric_loader: a tracker that can load metrics from a previous run. start: if positive, the epoch from which to start plotting; if negative, the last number of epochs. Defaults to all. """ super().__init__( model_names, source_names, metric_names, start, metric_loader ) self.server = server self.port = port self.opts = opts or VisdomOpts() self._viz = None return @property def viz(self) -> visdom.Visdom: """The active Visdom instance. Raises: AccessOutsideScopeError: if no run has been started yet. """ if self._viz is None: raise exceptions.AccessOutsideScopeError() return self._viz
[docs] @override def clean_up(self) -> None: self._viz = None return super().clean_up()
[docs] @functools.singledispatchmethod @override def notify(self, event: log_events.Event) -> None: return super().notify(event)
@notify.register def _(self, event: log_events.StartExperimentEvent) -> None: """Start the experiment and connect to the server. Raises: TrackerError: if the server is not available. """ env = event.exp_name + '_' + event.run_id try: self._viz = visdom.Visdom( server=self.server, port=self.port, env=env, raise_exceptions=True, ) except ConnectionError as cre: msg = 'server not available.' raise exceptions.TrackerError(self, msg) from cre else: self.viz.close(env=env) # close all the windows return super().notify(event) @override def _plot_metric( self, model_name: str, metric_name: str, **sourced_array: base_classes.NpArray, ) -> str: layout = VisdomOpts( xlabel='Epoch', ylabel=metric_name, title=model_name, showlegend=True, ) scatter_opts = VisdomOpts(mode='markers', markersymbol='24') opts = self.opts | layout # type: ignore win = '_'.join((model_name, metric_name)) for name, log in sourced_array.items(): self.viz.scatter(None, win=win, update='remove', name=name) if log.shape[0] > 1: self.viz.line( X=log[:, 0], Y=log[:, 1], opts=opts, update='append', win=win, name=name, ) else: self.viz.scatter( X=log[:, 0], Y=log[:, 1], opts=opts | scatter_opts, update='append', win=win, name=name, ) return win