"""Module for coordinating logging of metadata, internal messages and metrics.
Attributes:
DEFAULT_TRACKERS: named trackers registered to experiments by default.
"""
from __future__ import annotations
import abc
import datetime
import functools
import warnings
from abc import abstractmethod
from typing import Any, ClassVar, Final, Self
from drytorch.core import exceptions, log_events
from drytorch.core import protocols as p
from drytorch.utils import repr_utils
# Public API of this module.
# NOTE(review): 'MetadataManager' is listed here, but no definition of it is
# visible in this source; `from <module> import *` raises AttributeError for
# any name in __all__ that is missing, so confirm it is defined in this module.
__all__ = [
    'DEFAULT_TRACKERS',
    'EventDispatcher',
    'MetadataManager',
    'Tracker',
    'extend_default_trackers',
    'remove_all_default_trackers',
]
# Named trackers registered to experiments by default. Entries are added by
# extend_default_trackers (keyed by the tracker's class name) and removed all
# at once by remove_all_default_trackers.
DEFAULT_TRACKERS: dict[str, Tracker] = {}


class Tracker(metaclass=abc.ABCMeta):
    """Abstract base class for trackers notified of experiment events.

    Subclasses must implement ``notify``. The base implementation dispatches
    on the event type: a ``StartExperimentEvent`` registers the receiving
    instance as the current tracker of its class, a ``StopExperimentEvent``
    unregisters it and calls ``clean_up``, and any other event is ignored.
    """

    # Instance registered to the current experiment. The classmethods below
    # assign through ``cls``; when they are invoked from an instance, the
    # value is therefore stored on that instance's own class, not on Tracker.
    _current: ClassVar[Self | None] = None

    @functools.singledispatchmethod
    @abstractmethod
    def notify(self, event: log_events.Event) -> None:
        """Notify the tracker of an event.

        This fallback handles events with no registered overload and does
        nothing.

        Args:
            event: the event to notify about.
        """
        return

    @notify.register
    def _(self, event: log_events.StartExperimentEvent) -> None:
        # Dispatch only needs the event type; its payload is not used here.
        del event
        self._set_current(self)
        return

    @notify.register
    def _(self, event: log_events.StopExperimentEvent) -> None:
        del event
        # Unregister before cleaning up, so the tracker is no longer current
        # even if clean_up raises.
        self._reset_current()
        self.clean_up()
        return

    def clean_up(self) -> None:
        """Override to clean up the tracker."""
        return

    @classmethod
    def get_current(cls) -> Self:
        """Get the tracker instance registered to the current experiment.

        Returns:
            The instance of the tracker registered to the current experiment.

        Raises:
            TrackerNotUsedError: if no instance of this class is registered.
        """
        if cls._current is None:
            raise exceptions.TrackerNotUsedError(cls.__name__)
        return cls._current

    @classmethod
    def _set_current(cls, tracker: Self) -> None:
        """Record ``tracker`` as the current instance for ``cls``."""
        cls._current = tracker
        return

    @classmethod
    def _reset_current(cls) -> None:
        """Clear the current instance recorded for ``cls``."""
        cls._current = None
        return


class EventDispatcher:
    """Notifies trackers of events.

    Attributes:
        exp_name: name of the current experiment.
        named_trackers: a dictionary of trackers, indexed by their names.
    """

    exp_name: str
    named_trackers: dict[str, Tracker]

    def __init__(self, exp_name) -> None:
        """Initialize.

        Args:
            exp_name: name of the current experiment; stored as ``str``.
        """
        self.exp_name: Final = str(exp_name)
        self.named_trackers: Final[dict[str, Tracker]] = {}
        return

    def publish(self, event: log_events.Event) -> None:
        """Publish an event to all registered trackers.

        If a tracker's ``notify`` raises an ``Exception``, a
        ``TrackerExceptionWarning`` is emitted. Once every tracker has been
        notified, each failing tracker is cleaned up and removed; an exception
        raised by its ``clean_up`` is also reported as a warning.

        Args:
            event: the event to publish.

        Raises:
            KeyboardInterrupt: if a tracker raises KeyboardInterrupt.
            SystemExit: if a tracker raises SystemExit.
        """
        failed_names: list[str] = []
        for name, tracker in self.named_trackers.items():
            # KeyboardInterrupt and SystemExit derive from BaseException, not
            # Exception, so they are not caught here and propagate unchanged.
            try:
                tracker.notify(event)
            except Exception as err:
                warnings.warn(
                    exceptions.TrackerExceptionWarning(name, err), stacklevel=1
                )
                failed_names.append(name)

        # Failing trackers are removed only after the loop, because the
        # dictionary must not change size while it is being iterated.
        for name in failed_names:
            tracker = self.named_trackers[name]
            try:
                tracker.clean_up()
            except Exception as err:
                warnings.warn(
                    exceptions.TrackerExceptionWarning(name, err),
                    stacklevel=1,
                )
            self.remove(name)
        return

    def _subscribe_tracker(self, name: str, tracker: Tracker) -> None:
        """Subscribe a tracker to the dispatcher.

        Args:
            name: the name associated with the tracker.
            tracker: the tracker to register.

        Raises:
            TrackerAlreadyRegisteredError: if the name is already registered.
        """
        if name in self.named_trackers:
            raise exceptions.TrackerAlreadyRegisteredError(name, self.exp_name)
        self.named_trackers[name] = tracker
        return

    def subscribe(self, *trackers: Tracker, **named_trackers: Tracker) -> None:
        """Subscribe trackers to the dispatcher.

        Args:
            trackers: trackers to register with their class names.
            named_trackers: trackers to register with custom names.

        Raises:
            TrackerAlreadyRegisteredError: if a tracker is already registered.
        """
        for tracker in trackers:
            name = tracker.__class__.__name__
            self._subscribe_tracker(name, tracker)
        for name, tracker in named_trackers.items():
            self._subscribe_tracker(name, tracker)
        return

    def remove(self, tracker_name: str) -> None:
        """Remove a tracker by name from the dispatcher.

        Args:
            tracker_name: name of the tracker to remove.

        Raises:
            TrackerNotUsedError: if no tracker is registered under that name.
        """
        try:
            self.named_trackers.pop(tracker_name)
        except KeyError as ke:
            raise exceptions.TrackerNotUsedError(tracker_name) from ke
        return

    def remove_all(self) -> None:
        """Remove all trackers from the dispatcher."""
        # Iterate over a snapshot of the names, since ``remove`` mutates the
        # dictionary.
        for tracker_name in list(self.named_trackers):
            self.remove(tracker_name)
        return


def extend_default_trackers(tracker_list: list[Tracker]) -> None:
    """Register trackers as defaults, keyed by their class name.

    Args:
        tracker_list: trackers to add to ``DEFAULT_TRACKERS``. A tracker whose
            class name is already present replaces the existing entry.
    """
    DEFAULT_TRACKERS.update(
        (tracker.__class__.__name__, tracker) for tracker in tracker_list
    )
    return


def remove_all_default_trackers() -> None:
    """Empty the registry of default trackers.

    ``DEFAULT_TRACKERS`` is cleared in place rather than rebound, so code
    holding a reference to the dictionary sees it become empty.
    """
    registry = DEFAULT_TRACKERS
    registry.clear()
    return