Source code for drytorch.trackers.sqlalchemy

"""Module containing sqlalchemy Table classes and a tracker to track metrics."""

from __future__ import annotations

import datetime
import functools

from typing import ClassVar, cast

import sqlalchemy

from sqlalchemy import orm, sql
from typing_extensions import override

from drytorch.core import exceptions, log_events
from drytorch.trackers import base_classes


__all__ = [
    'Experiment',
    'Log',
    'Run',
    'SQLConnection',
    'Source',
    'Tags',
]


reg = orm.registry()


[docs] @reg.mapped_as_dataclass class Experiment: """Table for experiments. Attributes: row_id: the unique id for the table. experiment_name: the experiment's name. runs: the entry for the run for the experiment. tags: the list of tags for the experiment. """ __tablename__ = 'experiments' row_id: orm.Mapped[int] = orm.mapped_column( init=False, primary_key=True, autoincrement=True, ) experiment_name: orm.Mapped[str] = orm.mapped_column(index=True) tags: orm.Mapped[list[Tags]] = orm.relationship( init=False, cascade='all, delete-orphan', ) runs: orm.Mapped[list[Run]] = orm.relationship( init=False, cascade='all, delete-orphan', )
[docs] @reg.mapped_as_dataclass class Tags: """Table for tags for experiments. Attributes: row_id: the unique id for the table. experiment_id: the id for the experiment. text: a tag for the experiment. """ __tablename__ = 'tags' row_id: orm.Mapped[int] = orm.mapped_column( init=False, primary_key=True, autoincrement=True, ) experiment_id: orm.Mapped[int] = orm.mapped_column( sqlalchemy.ForeignKey(Experiment.row_id), init=False, ) experiment: orm.Mapped[Experiment] = orm.relationship( back_populates=Experiment.tags.key ) text: orm.Mapped[str] = orm.mapped_column(index=True)
[docs] @reg.mapped_as_dataclass class Run: """Table for runs. A new run is created for each experiment scope, unless specified. Attributes: row_id: the unique id for the table. run_id: global identifier for the run. run_ts: the run's timestamp. experiment_id: the id of the experiment for the run. experiment: the entry for the experiment for the run. sources: the list of sources from experiments """ __tablename__ = 'runs' row_id: orm.Mapped[int] = orm.mapped_column( init=False, primary_key=True, autoincrement=True, ) run_id: orm.Mapped[str] = orm.mapped_column(index=True) run_ts: orm.Mapped[datetime.datetime] = orm.mapped_column() experiment_id: orm.Mapped[int] = orm.mapped_column( sqlalchemy.ForeignKey(Experiment.row_id), init=False, ) experiment: orm.Mapped[Experiment] = orm.relationship( back_populates=Experiment.runs.key ) sources: orm.Mapped[list[Source]] = orm.relationship( init=False, cascade='all, delete-orphan', )
[docs] @reg.mapped_as_dataclass class Source: """Table for sources. Attributes: row_id: the unique id for the table. model_name: the model's name. model_ts: the model's timestamp. source_name: the source's name. source_ts: the source's timestamp. run_table_id: the table id for the current experiment's run. run: the entry for the current experiment's run. logs: the list of logs originating from the source. """ __tablename__ = 'sources' row_id: orm.Mapped[int] = orm.mapped_column( init=False, primary_key=True, autoincrement=True, ) model_name: orm.Mapped[str] = orm.mapped_column(index=True) model_ts: orm.Mapped[datetime.datetime] = orm.mapped_column() source_name: orm.Mapped[str] = orm.mapped_column(index=True) source_ts: orm.Mapped[datetime.datetime] = orm.mapped_column() run_table_id: orm.Mapped[int] = orm.mapped_column( sqlalchemy.ForeignKey(Run.row_id), init=False, ) run: orm.Mapped[Run] = orm.relationship(back_populates=Run.sources.key) logs: orm.Mapped[list[Log]] = orm.relationship( init=False, cascade='all, delete-orphan', )
[docs] @reg.mapped_as_dataclass class Log: """Table for the logs of the metrics. Attributes: row_id: the unique id for the table. source_id: the id of the source creating the log. source: the entry for the source creating the log. epoch: the number of epochs the model has been trained. metric_name: the name of the metric. value: the value of the metric. created_at: the timestamp for the entry creation. """ __tablename__ = 'logs' row_id: orm.Mapped[int] = orm.mapped_column( init=False, primary_key=True, autoincrement=True, ) source_id: orm.Mapped[int] = orm.mapped_column( sqlalchemy.ForeignKey(Source.row_id), init=False, ) source: orm.Mapped[Source] = orm.relationship( back_populates=Source.logs.key, ) epoch: orm.Mapped[int] = orm.mapped_column(index=True) metric_name: orm.Mapped[str] = orm.mapped_column(index=True) value: orm.Mapped[float] = orm.mapped_column() created_at: orm.Mapped[datetime.datetime] = orm.mapped_column( init=False, insert_default=sql.func.now(), )
[docs] class SQLConnection(base_classes.MetricLoader): """Tracker that creates a connection to a SQL database using sqlalchemy. Attributes: default_url: by default, it creates a local sqlite database. engine: the sqlalchemy Engine for the connection. session_factory: the Session class to initiate a sqlalchemy session. """ default_url: ClassVar[sqlalchemy.URL] = sqlalchemy.URL.create( 'sqlite', database='metrics.db' ) engine: sqlalchemy.Engine session_factory: orm.sessionmaker[orm.Session] _run: Run | None _sources: dict[str, Source] def __init__( self, engine: sqlalchemy.Engine | None = None, ) -> None: """Initialize. Args: engine: the engine for the session. Default uses default_url. """ super().__init__() self.engine = engine or sqlalchemy.create_engine(self.default_url) reg.metadata.create_all(bind=self.engine) self.session_factory = orm.sessionmaker(bind=self.engine) self._run = None self._sources = {} return @property def run(self) -> Run: """The current run. Raises: AccessOutsideScopeError: if there is no active run. """ if self._run is None: raise exceptions.AccessOutsideScopeError() return self._run
[docs] @override def clean_up(self) -> None: self._run = None self._sources.clear() self.engine.dispose() return super().clean_up()
[docs] @functools.singledispatchmethod @override def notify(self, event: log_events.Event) -> None: return super().notify(event)
@notify.register def _(self, event: log_events.StartExperimentEvent) -> None: with self.session_factory() as session: experiment = Experiment( experiment_name=event.exp_name, ) for tag_str in event.tags: tag = Tags(text=tag_str, experiment=experiment) session.add(tag) self._run = Run(event.run_id, event.run_ts, experiment) session.add(experiment) session.commit() return super().notify(event) @notify.register def _(self, event: log_events.ActorRegistrationEvent) -> None: with self.session_factory() as session: run = session.merge(self.run) source = Source( model_name=event.model_name, model_ts=event.model_ts, source_name=event.actor_name, source_ts=event.actor_ts, run=run, ) session.add(source) session.commit() self._sources[event.actor_name] = source return super().notify(event) @notify.register def _(self, event: log_events.MetricEvent) -> None: """Process metric events. Raises: TrackerError: if the source has not been registered. """ with self.session_factory() as session: if event.source_name not in self._sources: msg = f'Source {event.source_name} has not been registered.' raise exceptions.TrackerError(self, msg) source = session.merge(self._sources[event.source_name]) for metric_name, value in event.metrics.items(): new_row = Log( source=source, epoch=event.epoch, metric_name=metric_name, value=value, ) session.add(new_row) session.commit() return super().notify(event) def _find_sources(self, model_name: str) -> dict[str, list[Source]]: with self.session_factory() as session: run = session.merge(self.run) query = ( session.query(Source) .join(Source.run) .where( Run.run_id.is_(run.run_id), Source.model_name.is_(model_name), ) ) named_sources = dict[str, list[Source]]() for source in query: source = cast(Source, source) # fixing wrong annotation sources = named_sources.setdefault(source.source_name, []) sources.append(source) if not named_sources: msg = f'No sources for model {model_name}.' raise exceptions.TrackerError(self, msg) return named_sources def _get_run_metrics( self, sources: list[Source], max_epoch: int, ) -> base_classes.HistoryMetrics: with self.session_factory() as session: sources = [session.merge(source) for source in sources] query = session.query(Log).where( Log.source_id.in_(source.row_id for source in sources), ) if max_epoch != -1: query = query.where(Log.epoch <= max_epoch) named_epochs = dict[str, list[int]]() named_metric_values = dict[str, list[float]]() for log in query: log = cast(Log, log) # fixing wrong annotation epochs = named_epochs.setdefault(log.metric_name, []) epochs.append(log.epoch) values = named_metric_values.setdefault(log.metric_name, []) values.append(log.value) name = '' epochs = list[int]() for next_name, next_epochs in named_epochs.items(): if epochs and epochs != next_epochs: msg = f'{name} and {next_name} logs refer to different epochs.' raise exceptions.TrackerError(self, msg) epochs = next_epochs name = next_name return epochs, named_metric_values def _load_metrics( self, model_name: str, max_epoch: int = -1 ) -> base_classes.SourcedMetrics: last_sources = self._find_sources(model_name) out: base_classes.SourcedMetrics = {} for source_name, run_sources in last_sources.items(): out[source_name] = self._get_run_metrics( run_sources, max_epoch=max_epoch ) return out