Source code for drytorch.core.experimenting

"""Module containing the Experiment and Run class."""

from __future__ import annotations

import dataclasses
import gc
import json
import multiprocessing
import pathlib
import shutil
import subprocess
import types
import warnings
import weakref

from typing import (
    Any,
    ClassVar,
    Final,
    Generic,
    Literal,
    Self,
    TypeVar,
)

import filelock

from torch import distributed as dist
from typing_extensions import override

from drytorch.core import exceptions, log_events, tracking
from drytorch.utils import repr_utils


__all__ = [
    'Experiment',
    'Run',
]

_T_co = TypeVar('_T_co', covariant=True)
RunStatus = Literal['created', 'running', 'completed', 'failed']


@dataclasses.dataclass
class RunMetadata:
    """Metadata for a run."""

    id: str
    status: RunStatus
    timestamp: str
    commit: str | None


class RunRegistry:
    """Creates and manages a JSON file for run metadata.

    Attributes:
        file_path: path to the JSON file.
    """

    file_path: pathlib.Path
    lock_path: pathlib.Path

    def __init__(self, path: pathlib.Path):
        """Initialize.

        Args:
            path: path to the JSON file.
        """
        self.file_path = path
        self.lock_path = path.with_suffix('.json.lock')
        return

    def load_all(self) -> list[RunMetadata]:
        """Loads all run metadata from a JSON file."""
        if not self.file_path.exists():
            return []

        try:
            with self.file_path.open() as f:
                data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return []

        run_data = []
        for item in data:
            run_data.append(RunMetadata(**item))

        return run_data

    def register_new_run(self, run_metadata: RunMetadata) -> None:
        """Register a new run, ensuring a unique ID.

        Args:
            run_metadata: the metadata for the run (id will be updated).
        """
        self.file_path.parent.mkdir(parents=True, exist_ok=True)
        with filelock.FileLock(self.lock_path):
            run_data = self.load_all()
            run_data.append(run_metadata)

            # Convert to dicts for JSON serialization
            serialized_data = [dataclasses.asdict(r) for r in run_data]
            self.file_path.write_text(json.dumps(serialized_data, indent=2))

        return

    def update_run_status(self, run_id: str, status: RunStatus) -> None:
        """Update the status of a run.

        Args:
            run_id: the id of the run to update.
            status: the new status.
        """
        with filelock.FileLock(self.lock_path):
            run_data = self.load_all()
            for run in run_data:
                if run.id == run_id:
                    run.status = status
                    break
            else:
                raise exceptions.RunNotRecordedError(run_id)

            # Convert to dicts for JSON serialization
            serialized_data = [dataclasses.asdict(r) for r in run_data]
            self.file_path.write_text(json.dumps(serialized_data, indent=2))
        return


[docs] class Experiment(Generic[_T_co]): """Manage experiment configuration, directory, and tracking. This class associates a configuration file, a name, and a working directory with a machine learning experiment. It also contains the trackers responsible for tracking the metadata and metrics for the experiment. Finally, it allows global access to a configuration file with the correct type annotations. Attributes: folder_name: name of the hidden folder storing experiment metadata. run_file: filename storing the registry of run IDs for this experiment. previous_runs: a list of all previous runs created by this class. par_dir: parent directory for experiment data. tags: descriptors for the experiment. trackers: dispatcher for publishing events. """ folder_name: ClassVar[str] = '.drytorch' run_file: ClassVar[str] = 'runs.json' previous_runs: ClassVar[list[Run]] = [] _name = repr_utils.DefaultName() __current: ClassVar[Experiment[Any] | None] = None par_dir: pathlib.Path __config: _T_co tags: list[str] trackers: tracking.EventDispatcher _registry: RunRegistry _active_run: Run[_T_co] | None def __init__( self, config: _T_co, *, name: str = '', par_dir: str | pathlib.Path = pathlib.Path(), tags: list[str] | None = None, ) -> None: """Initialize. Args: config: Configuration for the experiment. name: The name of the experiment (defaults to class name). par_dir: Parent directory for experiment data. tags: Descriptors for the experiment (e.g., ``"lr=0.01"``). """ _validate_chars(name) self.__config: Final = config self._name = name self.par_dir = pathlib.Path(par_dir) self.tags = tags or [] self.trackers: Final = tracking.EventDispatcher(self.name) self.trackers.subscribe(**tracking.DEFAULT_TRACKERS) run_file = self.par_dir / self.folder_name / self.name / self.run_file self._registry = RunRegistry(run_file) self._active_run = None return @property def name(self) -> str: """The name of the experiment.""" return self._name @property def config(self) -> _T_co: """Experiment configuration.""" return self.__config
[docs] def create_run( self, *, run_id: str | None = None, resume: bool = False, record: bool = True, ) -> Run[_T_co]: """Convenience constructor for a Run using this experiment. Args: run_id: identifier of the run; defaults to timestamp. resume: resume the selected run if run_id is set, else the last run. record: record the run in the registry. Returns: The created run object. Raises: RunAlreadyRecordedError: if creating a new run with an existing id. """ if run_id is not None: _validate_chars(run_id) runs_data = self._registry.load_all() if resume: return self._handle_resume_logic(run_id, runs_data, record) if runs_data and run_id in [r.id for r in runs_data]: raise exceptions.RunAlreadyRecordedError(run_id, self.name) return self._create_new_run(run_id, record)
def _handle_resume_logic( self, run_id: str | None, runs_data: list[RunMetadata], record: bool ) -> Run[_T_co]: """Handle resume logic for existing runs.""" if self.previous_runs: run = self._get_run_from_previous(run_id) if run: run.resumed = True run.status = 'created' return run if not runs_data: warnings.warn(exceptions.NoPreviousRunsWarning(), stacklevel=2) return self._create_new_run(run_id, record) if run_id is None: run_id = runs_data[-1].id else: matching_runs = [r for r in runs_data if r.id == run_id] if not matching_runs: warnings.warn( exceptions.NotExistingRunWarning(run_id), stacklevel=1 ) return self._create_new_run(run_id, record) if len(matching_runs) > 1: msg = f'Multiple runs with id {run_id} found in the registry.' raise RuntimeError(msg) return Run(experiment=self, run_id=run_id, resumed=True, record=record) def _get_run_from_previous(self, run_id: str | None) -> Run[_T_co] | None: """Get run from the previous_runs list.""" if run_id is None: return self.previous_runs[-1] matching_runs = [r for r in self.previous_runs if r.id == run_id] if not matching_runs: return None matching_run, *other_runs = matching_runs if other_runs: msg = f'Multiple runs with id {run_id} found for exp {self.name}' raise RuntimeError(msg) return matching_run def _create_new_run( self, run_id: str | None, record: bool, ) -> Run[_T_co]: """Create a new run (non-resume case).""" run = Run(experiment=self, run_id=run_id, record=record) run_data = RunMetadata( id=run.id, status='created', timestamp=run.created_at_str, commit=self._get_last_commit_hash(), ) if run.record: self._registry.register_new_run(run_data) return run @property def run(self) -> Run[_T_co]: """Get the current run. Raises: NoActiveExperimentError: if no run is currently active. """ if self._active_run is None: raise exceptions.NoActiveExperimentError(self.name) return self._active_run @run.setter def run(self, current_run: Run[_T_co]) -> None: self._active_run = current_run return
[docs] @classmethod def get_config(cls) -> _T_co: """Retrieve the configuration of the current experiment.""" return cls.get_current().__config
[docs] @classmethod def get_current(cls) -> Self: """Return the currently active experiment. Raises: NoActiveExperimentError: if no experiment is currently active. """ if Experiment.__current is None: raise exceptions.NoActiveExperimentError() if not isinstance(Experiment.__current, cls): raise exceptions.NoActiveExperimentError(experiment_class=cls) return Experiment.__current
[docs] @staticmethod def set_current(experiment: Experiment[_T_co]) -> None: """Set an experiment as active. Raises: NestedScopeError: if there is an already active run. """ if (old_exp := Experiment.__current) is not None: raise exceptions.NestedScopeError(old_exp.name, experiment.name) Experiment.__current = experiment return
@staticmethod def _clear_current() -> None: """Clear the active experiment.""" Experiment.__current = None return @staticmethod def _get_last_commit_hash() -> str | None: """Get the last commit hash if available otherwise None.""" git = shutil.which('git') if git is None: return None try: result = subprocess.run( # noqa: S603 [git, 'rev-parse', 'HEAD'], capture_output=True, text=True, check=True, ) except subprocess.CalledProcessError: return None return result.stdout.strip() @override def __repr__(self) -> str: return f'{self.__class__.__name__}(name={self.name})'
[docs] class Run(repr_utils.CreatedAtMixin, Generic[_T_co]): """Execution lifecycle for a single run of an Experiment. Attributes: status: Current status of the run. resumed: whether the run was resumed. metadata_manager: Manager for run metadata. record: whether to record the run in the registry. """ _experiment: Experiment[_T_co] _is_distributed: bool _is_main_process: bool _id: str resumed: bool record: bool status: RunStatus metadata_manager: tracking.MetadataManager _finalizer: weakref.finalize[..., Self] | None def __init__( self, experiment: Experiment[_T_co], run_id: str | None, resumed: bool = False, record: bool = True, ) -> None: """Initialize. Args: experiment: the experiment this run belongs to. run_id: identifier of the run. resumed: whether the run was resumed. record: record the run in the registry. """ super().__init__() self._experiment: Final = experiment self._is_distributed = dist.is_available() and dist.is_initialized() self._is_main_process = not self._is_distributed or dist.get_rank() == 0 self._id: Final = self._get_run_id(run_id) self.resumed = resumed self.record = record and self._is_main_process self.status = 'created' self.metadata_manager: Final = tracking.MetadataManager() self._finalizer = None if not self.resumed: experiment.previous_runs.append(self) if self._is_distributed: feature = 'Data-distributed support' warnings.warn( exceptions.ExperimentalFeatureWarning(feature), stacklevel=2 ) return @property def experiment(self) -> Experiment[_T_co]: """The experiment this run belongs to.""" return self._experiment @property def id(self) -> str: """The identifier of the run.""" return self._id
[docs] def __enter__(self) -> Self: """Enter the experiment scope.""" self.start() return self
[docs] def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: types.TracebackType | None, ) -> None: """Exit the experiment scope.""" if exc_type is not None: self.status = 'failed' self.stop() return
[docs] def is_active(self) -> bool: """Check if the run is currently active.""" return self.status == 'running'
[docs] def stop(self) -> None: """Stop the experiment scope.""" if self.status == 'running': # failed is left as is self.status = 'completed' elif self.status == 'completed': warnings.warn(exceptions.RunAlreadyCompletedWarning(), stacklevel=1) return if self.status == 'created': warnings.warn(exceptions.RunNotStartedWarning(), stacklevel=1) return if self.record: self._update_registry() if self._finalizer is not None: self._finalizer.detach() self._finalizer = None self._stop_experiment(self.experiment, self._id) return
[docs] def start(self: Self) -> None: """Start the experiment scope.""" if self.status == 'running': warnings.warn(exceptions.RunAlreadyRunningWarning(), stacklevel=1) return self._finalizer = weakref.finalize( self, self._stop_experiment, self._experiment, self._id ) self.status = 'running' if self.record: self._update_registry() self._experiment._active_run = self Experiment.set_current(self._experiment) if self._is_main_process: log_events.Event.set_auto_publish(self._experiment.trackers.publish) else: # no tracking in secondary processes log_events.Event.set_auto_publish(lambda _: None) log_events.StartExperimentEvent( self._experiment.config, self._experiment.name, self.created_at, self._id, self.resumed, self._experiment.par_dir, self._experiment.tags, ) return
def _get_run_id(self, run_id: str | None) -> str: """Generate a run ID, appending PID if in a worker process.""" final_id = run_id or self.created_at_str if not self._is_distributed: # keep the same ID for distributed runs if multiprocessing.current_process().name != 'MainProcess': final_id = f'{final_id}_{multiprocessing.current_process().pid}' return final_id def _update_registry(self) -> None: """Update the run status in the experiment's registry.""" self.experiment._registry.update_run_status(self.id, self.status) return @staticmethod def _stop_experiment(experiment: Experiment[_T_co], run_id: str) -> None: """Cleanup without holding reference to a Run instance.""" log_events.StopExperimentEvent(experiment.name, run_id) log_events.Event.set_auto_publish(None) experiment._active_run = None Experiment._clear_current() gc.collect() return @override def __repr__(self) -> str: return f'{self.__class__.__name__}(id={self.id}, status={self.status})'
def _validate_chars(name: str) -> None: if len(name) > 255: msg = f'Name is too long (max 255 chars): {len(name)}' raise ValueError(msg) not_allowed_chars = set(r'\/:*?"<>|') if invalid_chars := set(name) & not_allowed_chars: msg = f'Name contains invalid character(s): {invalid_chars!r}' raise ValueError(msg)