Source code for drytorch.trackers.csv

"""Module containing a tracker that dumps metrics in a CSV file."""

import csv
import functools
import pathlib

from typing import ClassVar, Final

from typing_extensions import override

from drytorch.core import exceptions, log_events
from drytorch.trackers import base_classes


__all__ = [
    'CSVDumper',
]


class DryTorchDialect(csv.Dialect):
    """Dialect similar to excel that converts numbers to floats.

    Attributes:
        delimiter: delimiter character.
        quotechar: quote character.
        doublequote: whether to enable double quoting.
        skipinitialspace: whether to skip initial whitespace.
        lineterminator: line terminator.
        quoting: quoting style.
    """

    delimiter: ClassVar[str] = ','
    quotechar: ClassVar[str] = '"'
    doublequote: ClassVar[bool] = True
    skipinitialspace: ClassVar[bool] = False
    lineterminator: ClassVar[str] = '\r\n'
    quoting: ClassVar[int] = csv.QUOTE_NONNUMERIC


[docs] class CSVDumper(base_classes.Dumper, base_classes.MetricLoader): """Dump metrics into a CSV file.""" folder_name = 'csv_metrics' _default_dialect: ClassVar[csv.Dialect] = DryTorchDialect() _base_headers: ClassVar[tuple[str, ...]] = ('Model', 'Source', 'Epoch') _active_sources: set[str] _dialect: csv.Dialect _resume_run: bool def __init__( self, par_dir: pathlib.Path | None = None, dialect: csv.Dialect = _default_dialect, ) -> None: """Initialize. Args: par_dir: the parent directory for the tracker data. Default uses the same of the current experiment. dialect: the format specification. Defaults to local dialect. """ super().__init__(par_dir) self._active_sources: Final = set() self._dialect = dialect self._resume_run = False return
[docs] @functools.singledispatchmethod @override def notify(self, event: log_events.Event) -> None: return super().notify(event)
@notify.register def _(self, event: log_events.StartExperimentEvent) -> None: self._resume_run = event.resumed return super().notify(event) @notify.register def _(self, event: log_events.MetricEvent) -> None: run_dir = self._get_run_dir() file_address = self._file_path( run_dir, event.model_name, event.source_name ) metric_names = tuple(event.metrics) headers = self._base_headers + metric_names if event.source_name not in self._active_sources: if self._resume_run and file_address.exists(): with file_address.open(newline='') as log: reader = csv.reader(log, dialect=self._dialect) previous_headers = tuple(next(reader)) if headers != previous_headers: msg = ( f'Current {headers=} and previous headers=' f'{previous_headers} do not correspond' ) raise exceptions.TrackerError(self, msg) else: with file_address.open('w', newline='') as log: # reset writer = csv.writer(log, dialect=self._dialect) writer.writerow(headers) self._active_sources.add(event.source_name) with file_address.open('a', newline='') as log: writer = csv.writer(log, dialect=self._dialect) writer.writerow( [ event.model_name, event.source_name, event.epoch, *event.metrics.values(), ] ) return super().notify(event)
[docs] def read_csv( self, model_name: str, source: str, max_epoch: int = -1, ) -> base_classes.HistoryMetrics: """Read the CSV file associated with the given model and source. Args: model_name: the name of the model. source: the source of the metrics. max_epoch: the maximum number of epochs to load. Defaults to all. Returns: Epochs and relative value for each metric. """ run_dir = self._get_run_dir() file_address = self._file_path(run_dir, model_name, source) with file_address.open(newline='') as log: reader = csv.reader(log, dialect=self._dialect) headers = next(reader) len_base = len(self._base_headers) metric_names = headers[len_base:] epochs = list[int]() named_metric_values = dict[str, list[float]]() epoch_column = self._base_headers.index('Epoch') for row in reader: epoch = int(row[epoch_column]) if epochs and epochs[-1] >= epoch: # only load the last run epochs.clear() named_metric_values.clear() if max_epoch != -1 and epoch > max_epoch: continue epochs.append(epoch) for metric, value in zip( metric_names, row[len_base:], strict=True ): value_list = named_metric_values.setdefault(metric, []) value_list.append(float(value)) return epochs, named_metric_values
[docs] @override def clean_up(self) -> None: self._resume_run = False return super().clean_up()
def _find_sources(self, model_name: str) -> set[str]: path = self._get_run_dir() / model_name named_sources = {file.stem for file in path.glob('*.csv')} if not named_sources: msg = f'No sources in {path}.' raise exceptions.TrackerError(self, msg) return named_sources @staticmethod def _file_path( run_dir: pathlib.Path, model_name: str, source_name: str, ) -> pathlib.Path: model_path = run_dir / model_name model_path.mkdir(exist_ok=True) return model_path / f'{source_name}.csv' def _load_metrics( self, model_name: str, max_epoch: int = -1 ) -> base_classes.SourcedMetrics: out: base_classes.SourcedMetrics = {} if self._resume_run: sources = self._find_sources(model_name) elif self._active_sources: sources = self._active_sources else: sources = set() for source in sources: out[source] = self.read_csv(model_name, source, max_epoch) return out