Source code for drytorch.trackers.wandb

"""Module containing a tracker calling Weights and Biases."""

import copy
import functools
import pathlib
import warnings

from typing import ClassVar

import wandb

from typing_extensions import override
from wandb.sdk import wandb_run, wandb_settings

from drytorch.core import exceptions, log_events
from drytorch.trackers.base_classes import Dumper
from drytorch.utils import repr_utils


__all__ = [
    'Wandb',
    'WandbWarning',
]


class WandbWarning(exceptions.DryTorchWarning):
    """Warning raised by the Weights and Biases tracker.

    The message template below prefixes messages with ``Wandb:``; the
    template is presumably consumed by the DryTorchWarning base class.
    """

    _template = 'Wandb: {}'
class Wandb(Dumper):
    """Tracker that wraps a run for the wandb library.

    A wandb run is started (or resumed) on StartExperimentEvent, metric
    events are logged to it, and it is finished on StopExperimentEvent.
    """

    # fallback settings used when no settings object is given
    _default_settings: ClassVar[wandb_settings.Settings] = (
        wandb_settings.Settings()
    )
    folder_name = 'wandb'
    # user-provided settings; never handed to wandb.init directly because
    # wandb.init writes its explicit arguments into the object it receives
    _settings: wandb_settings.Settings
    # run owned by this tracker, None outside an experiment
    _run: wandb_run.Run | None
    # metric names already registered with run.define_metric for this run
    _defined_metrics: set[str]

    def __init__(
        self,
        par_dir: pathlib.Path | None = None,
        settings: wandb_settings.Settings | None = None,
    ) -> None:
        """Initialize.

        Args:
            par_dir: the parent directory for the tracker data. Default uses
                the same of the current experiment.
            settings: settings object from wandb containing all init
                arguments. Defaults to the class-level default settings.
        """
        super().__init__(par_dir)
        if settings is None:
            settings = self._default_settings
        self._settings = settings
        self._run = None
        self._defined_metrics = set()
        return

    @property
    def run(self) -> wandb_run.Run:
        """Active wandb run instance.

        Raises:
            AccessOutsideScopeError: if no run has been started yet.
        """
        if self._run is None:
            raise exceptions.AccessOutsideScopeError()
        return self._run

    @override
    def clean_up(self) -> None:
        """Finish the run owned by this tracker and reset per-run state.

        Cleanup is best-effort: errors raised while finishing the run are
        reported as a WandbWarning instead of being propagated.
        """
        try:
            # only finish our own run, not whatever global run is active
            if self._run is not None:
                self._run.finish()
        except Exception as e:
            warnings.warn(
                WandbWarning(f'Error during cleanup: {e}'), stacklevel=1
            )
        finally:
            self._run = None
            # metric definitions are per run: a new run must redefine them
            self._defined_metrics = set()
        return

    @functools.singledispatchmethod
    @override
    def notify(self, event: log_events.Event) -> None:
        """Dispatch the event to the handler registered for its type.

        Events without a dedicated handler are forwarded to the base class.
        """
        return super().notify(event)

    @notify.register
    def _(self, event: log_events.StartExperimentEvent) -> None:
        """Start or resume the wandb run for the experiment.

        The run id is, in order of precedence: ``settings.run_id``; when
        resuming, the id of the most recent run in the same project and
        group; otherwise ``<exp_name>_<run_id>``.
        """
        super().notify(event)
        project = self._settings.project or event.exp_name
        group = self._settings.run_group or event.exp_name
        run_id = self._settings.run_id or ''
        if event.resumed and not run_id:
            api = wandb.Api()
            entity = self._settings.entity or api.default_entity
            runs = api.runs(
                f'{entity}/{project}',
                filters={'group': group},  # same group the run is logged to
                order='-created_at',  # newest first: resume the latest run
            )
            try:
                run_id = runs[0].id
            except (IndexError, ValueError):
                msg = 'No previous runs. Starting a new one.'
                warnings.warn(WandbWarning(msg), stacklevel=2)

        if not run_id:
            run_id = event.exp_name + '_' + event.run_id

        repr_config = repr_utils.recursive_repr(event.config, depth=1000)
        self._run = wandb.init(
            id=run_id,
            dir=self.par_dir.as_posix(),
            project=project,
            group=group,
            config=repr_config,
            tags=event.tags,
            # wandb.init assigns its explicit arguments onto the settings it
            # receives; a copy keeps self._settings (possibly the shared
            # class default) clean for later experiments
            settings=copy.copy(self._settings),
            resume='allow' if event.resumed else None,
        )
        return

    @notify.register
    def _(self, event: log_events.StopExperimentEvent) -> None:
        """Finish the run before forwarding the event to the base class."""
        self.clean_up()
        return super().notify(event)

    @notify.register
    def _(self, event: log_events.MetricEvent) -> None:
        """Process metric events.

        Each metric is logged as ``<model>/<source>-<metric>`` against the
        step metric ``Progress/<model>``, which holds the event's epoch.

        Raises:
            AccessOutsideScopeError: if called outside an active run scope.
        """
        run = self.run  # the property raises when no run is active
        plot_names = {
            f'{event.model_name}/{event.source_name}-{name}': value
            for name, value in event.metrics.items()
        }
        step_key = f'Progress/{event.model_name}'
        plot_step = {step_key: event.epoch}
        # define new metrics only once per run
        for name in plot_names:
            if name not in self._defined_metrics:
                run.define_metric(name, step_metric=step_key)
                self._defined_metrics.add(name)

        run.log(plot_names | plot_step)
        return super().notify(event)