"""Module containing the Experiment and Run class."""
from __future__ import annotations
import dataclasses
import gc
import json
import multiprocessing
import pathlib
import shutil
import subprocess
import types
import warnings
import weakref
from typing import (
Any,
ClassVar,
Final,
Generic,
Literal,
Self,
TypeVar,
)
import filelock
from torch import distributed as dist
from typing_extensions import override
from drytorch.core import exceptions, log_events, tracking
from drytorch.utils import repr_utils
# Names exported by ``from <module> import *``.
__all__ = [
    'Experiment',
    'Run',
]

# Covariant type variable standing for the experiment configuration type.
_T_co = TypeVar('_T_co', covariant=True)

# Lifecycle states a run can be in; this is the value stored in
# ``RunMetadata.status`` and held by ``Run.status``.
RunStatus = Literal['created', 'running', 'completed', 'failed']
@dataclasses.dataclass
class RunMetadata:
    """Metadata for a run.

    One instance corresponds to one entry of the JSON run registry.

    Attributes:
        id: identifier of the run.
        status: lifecycle status of the run.
        timestamp: creation time of the run, formatted as a string.
        commit: hash of the last git commit when the run was created, or
            None when it could not be determined.
    """

    id: str
    status: RunStatus
    timestamp: str
    commit: str | None
class RunRegistry:
    """Creates and manages a JSON file for run metadata.

    The file holds a JSON list with one object per run. Every
    read-modify-write cycle is performed while holding a lock file that
    lives next to the JSON file.

    Attributes:
        file_path: path to the JSON file.
        lock_path: path to the lock file guarding updates of the JSON file.
    """

    file_path: pathlib.Path
    lock_path: pathlib.Path

    def __init__(self, path: pathlib.Path):
        """Initialize.

        Args:
            path: path to the JSON file.
        """
        self.file_path = path
        self.lock_path = path.with_suffix('.json.lock')
        return

    def load_all(self) -> list[RunMetadata]:
        """Loads all run metadata from a JSON file.

        Returns:
            The recorded runs in file order, or an empty list when the file
            is missing or does not contain valid JSON.
        """
        # EAFP: a missing file is reported by open() itself, so no separate
        # existence check is needed.
        try:
            with self.file_path.open() as f:
                data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return []

        return [RunMetadata(**item) for item in data]

    def register_new_run(self, run_metadata: RunMetadata) -> None:
        """Append a new run to the registry.

        No uniqueness check is performed here: the metadata is stored as
        given, so callers must make sure the id is not already recorded.

        Args:
            run_metadata: the metadata for the run.
        """
        self.file_path.parent.mkdir(parents=True, exist_ok=True)
        with filelock.FileLock(self.lock_path):
            run_data = self.load_all()
            run_data.append(run_metadata)
            self._write_all(run_data)

        return

    def update_run_status(self, run_id: str, status: RunStatus) -> None:
        """Update the status of a run.

        Only the first entry whose id matches is updated.

        Args:
            run_id: the id of the run to update.
            status: the new status.

        Raises:
            RunNotRecordedError: if no recorded run has the given id.
        """
        with filelock.FileLock(self.lock_path):
            run_data = self.load_all()
            for run in run_data:
                if run.id == run_id:
                    run.status = status
                    break
            else:
                raise exceptions.RunNotRecordedError(run_id)

            self._write_all(run_data)

        return

    def _write_all(self, run_data: list[RunMetadata]) -> None:
        """Overwrite the JSON file with the given runs.

        The caller is expected to hold the registry lock.

        Args:
            run_data: the complete list of runs to store.
        """
        # Convert to dicts for JSON serialization
        serialized_data = [dataclasses.asdict(r) for r in run_data]
        self.file_path.write_text(json.dumps(serialized_data, indent=2))
        return
class Experiment(Generic[_T_co]):
    """Manage experiment configuration, directory, and tracking.

    This class associates a configuration file, a name, and a working directory
    with a machine learning experiment. It also contains the trackers
    responsible for tracking the metadata and metrics for the experiment.
    Finally, it allows global access to a configuration file with the correct
    type annotations.

    Attributes:
        folder_name: name of the hidden folder storing experiment metadata.
        run_file: filename storing the registry of run IDs for this experiment.
        previous_runs: a list of all previous runs created by this class.
            It is a single class-level list shared by every experiment.
        par_dir: parent directory for experiment data.
        tags: descriptors for the experiment.
        trackers: dispatcher for publishing events.
    """

    folder_name: ClassVar[str] = '.drytorch'
    run_file: ClassVar[str] = 'runs.json'
    previous_runs: ClassVar[list[Run]] = []
    _name = repr_utils.DefaultName()
    __current: ClassVar[Experiment[Any] | None] = None
    par_dir: pathlib.Path
    __config: _T_co
    tags: list[str]
    trackers: tracking.EventDispatcher
    _registry: RunRegistry
    _active_run: Run[_T_co] | None

    def __init__(
        self,
        config: _T_co,
        *,
        name: str = '',
        par_dir: str | pathlib.Path = pathlib.Path(),
        tags: list[str] | None = None,
    ) -> None:
        """Initialize.

        Args:
            config: Configuration for the experiment.
            name: The name of the experiment (defaults to class name).
            par_dir: Parent directory for experiment data.
            tags: Descriptors for the experiment (e.g., ``"lr=0.01"``).

        Raises:
            ValueError: if the name is too long or has invalid characters.
        """
        _validate_chars(name)
        self.__config: Final = config
        self._name = name
        self.par_dir = pathlib.Path(par_dir)
        self.tags = tags or []
        self.trackers: Final = tracking.EventDispatcher(self.name)
        self.trackers.subscribe(**tracking.DEFAULT_TRACKERS)
        # registry location: <par_dir>/<folder_name>/<name>/<run_file>
        run_file = self.par_dir / self.folder_name / self.name / self.run_file
        self._registry = RunRegistry(run_file)
        self._active_run = None
        return

    @property
    def name(self) -> str:
        """The name of the experiment."""
        return self._name

    @property
    def config(self) -> _T_co:
        """Experiment configuration."""
        return self.__config

    def create_run(
        self,
        *,
        run_id: str | None = None,
        resume: bool = False,
        record: bool = True,
    ) -> Run[_T_co]:
        """Convenience constructor for a Run using this experiment.

        Args:
            run_id: identifier of the run; defaults to timestamp.
            resume: resume the selected run if run_id is set, else the last run.
            record: record the run in the registry.

        Returns:
            The created run object.

        Raises:
            RunAlreadyRecordedError: if creating a new run with an existing id.
            ValueError: if run_id is too long or has invalid characters.
        """
        if run_id is not None:
            _validate_chars(run_id)

        runs_data = self._registry.load_all()
        if resume:
            return self._handle_resume_logic(run_id, runs_data, record)

        if any(r.id == run_id for r in runs_data):
            raise exceptions.RunAlreadyRecordedError(run_id, self.name)

        return self._create_new_run(run_id, record)

    def _handle_resume_logic(
        self, run_id: str | None, runs_data: list[RunMetadata], record: bool
    ) -> Run[_T_co]:
        """Handle resume logic for existing runs.

        Runs still held in ``previous_runs`` are preferred; otherwise the
        registry content is used. When nothing can be resumed, a warning is
        issued and a new run is created instead.

        Args:
            run_id: identifier of the run to resume, or None for the last one.
            runs_data: the runs currently recorded in the registry.
            record: record the run in the registry.

        Returns:
            The resumed run, or a new run if none could be resumed.

        Raises:
            RuntimeError: if more than one run matches the given id.
        """
        if self.previous_runs:
            run = self._get_run_from_previous(run_id)
            if run:
                run.resumed = True
                run.status = 'created'
                return run

        # stacklevel=3 skips this helper and create_run, so that both
        # warnings below are attributed to the code calling create_run
        if not runs_data:
            warnings.warn(exceptions.NoPreviousRunsWarning(), stacklevel=3)
            return self._create_new_run(run_id, record)

        if run_id is None:
            run_id = runs_data[-1].id
        else:
            matching_runs = [r for r in runs_data if r.id == run_id]
            if not matching_runs:
                warnings.warn(
                    exceptions.NotExistingRunWarning(run_id), stacklevel=3
                )
                return self._create_new_run(run_id, record)

            if len(matching_runs) > 1:
                msg = f'Multiple runs with id {run_id} found in the registry.'
                raise RuntimeError(msg)

        return Run(experiment=self, run_id=run_id, resumed=True, record=record)

    def _get_run_from_previous(self, run_id: str | None) -> Run[_T_co] | None:
        """Get run from the previous_runs list.

        The caller must make sure that ``previous_runs`` is not empty.

        Args:
            run_id: identifier of the run, or None for the most recent one.

        Returns:
            The matching run, or None if no run has the given id.

        Raises:
            RuntimeError: if more than one run has the given id.
        """
        if run_id is None:
            return self.previous_runs[-1]

        matching_runs = [r for r in self.previous_runs if r.id == run_id]
        if not matching_runs:
            return None

        matching_run, *other_runs = matching_runs
        if other_runs:
            msg = f'Multiple runs with id {run_id} found for exp {self.name}'
            raise RuntimeError(msg)

        return matching_run

    def _create_new_run(
        self,
        run_id: str | None,
        record: bool,
    ) -> Run[_T_co]:
        """Create a new run (non-resume case).

        Args:
            run_id: identifier of the run, or None to let the run choose it.
            record: record the run in the registry.

        Returns:
            The new run, registered when its ``record`` attribute is set.
        """
        run = Run(experiment=self, run_id=run_id, record=record)
        # run.record can be False even when record is True, so it is the
        # value to check; git is only queried when the metadata is stored
        if run.record:
            run_data = RunMetadata(
                id=run.id,
                status='created',
                timestamp=run.created_at_str,
                commit=self._get_last_commit_hash(),
            )
            self._registry.register_new_run(run_data)

        return run

    @property
    def run(self) -> Run[_T_co]:
        """Get the current run.

        Raises:
            NoActiveExperimentError: if no run is currently active.
        """
        if self._active_run is None:
            raise exceptions.NoActiveExperimentError(self.name)

        return self._active_run

    @run.setter
    def run(self, current_run: Run[_T_co]) -> None:
        self._active_run = current_run
        return

    @classmethod
    def get_config(cls) -> _T_co:
        """Retrieve the configuration of the current experiment.

        Raises:
            NoActiveExperimentError: if no experiment is currently active.
        """
        return cls.get_current().__config

    @classmethod
    def get_current(cls) -> Self:
        """Return the currently active experiment.

        Raises:
            NoActiveExperimentError: if no experiment is currently active, or
                if the active experiment is not an instance of this class.
        """
        if Experiment.__current is None:
            raise exceptions.NoActiveExperimentError()

        if not isinstance(Experiment.__current, cls):
            raise exceptions.NoActiveExperimentError(experiment_class=cls)

        return Experiment.__current

    @staticmethod
    def set_current(experiment: Experiment[_T_co]) -> None:
        """Set an experiment as active.

        Raises:
            NestedScopeError: if there is an already active run.
        """
        if (old_exp := Experiment.__current) is not None:
            raise exceptions.NestedScopeError(old_exp.name, experiment.name)

        Experiment.__current = experiment
        return

    @staticmethod
    def _clear_current() -> None:
        """Clear the active experiment."""
        Experiment.__current = None
        return

    @staticmethod
    def _get_last_commit_hash() -> str | None:
        """Get the last commit hash if available otherwise None.

        The hash is read with ``git rev-parse HEAD`` in the current working
        directory.
        """
        git = shutil.which('git')
        if git is None:
            return None

        try:
            result = subprocess.run(  # noqa: S603
                [git, 'rev-parse', 'HEAD'],
                capture_output=True,
                text=True,
                check=True,
            )
        except (subprocess.CalledProcessError, OSError):
            # either git reported an error (e.g., not a repository) or the
            # process could not be started: the hash is just not available
            return None

        return result.stdout.strip()

    @override
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(name={self.name})'
class Run(repr_utils.CreatedAtMixin, Generic[_T_co]):
    """Execution lifecycle for a single run of an Experiment.

    A run can be used as a context manager: entering starts the run and
    exiting stops it, marking it as failed if an exception was raised.

    Attributes:
        status: Current status of the run.
        resumed: whether the run was resumed.
        metadata_manager: Manager for run metadata.
        record: whether to record the run in the registry.
    """

    _experiment: Experiment[_T_co]
    _is_distributed: bool
    _is_main_process: bool
    _id: str
    resumed: bool
    record: bool
    status: RunStatus
    metadata_manager: tracking.MetadataManager
    _finalizer: weakref.finalize[..., Self] | None

    def __init__(
        self,
        experiment: Experiment[_T_co],
        run_id: str | None,
        resumed: bool = False,
        record: bool = True,
    ) -> None:
        """Initialize.

        Args:
            experiment: the experiment this run belongs to.
            run_id: identifier of the run.
            resumed: whether the run was resumed.
            record: record the run in the registry.
        """
        super().__init__()
        self._experiment: Final = experiment
        self._is_distributed = dist.is_available() and dist.is_initialized()
        self._is_main_process = not self._is_distributed or dist.get_rank() == 0
        self._id: Final = self._get_run_id(run_id)
        self.resumed = resumed
        # only the main process writes to the registry
        self.record = record and self._is_main_process
        self.status = 'created'
        self.metadata_manager: Final = tracking.MetadataManager()
        self._finalizer = None
        if not self.resumed:
            experiment.previous_runs.append(self)

        if self._is_distributed:
            feature = 'Data-distributed support'
            warnings.warn(
                exceptions.ExperimentalFeatureWarning(feature), stacklevel=2
            )

        return

    @property
    def experiment(self) -> Experiment[_T_co]:
        """The experiment this run belongs to."""
        return self._experiment

    @property
    def id(self) -> str:
        """The identifier of the run."""
        return self._id

    def __enter__(self) -> Self:
        """Enter the experiment scope."""
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Exit the experiment scope."""
        if exc_type is not None:
            self.status = 'failed'

        self.stop()
        return

    def is_active(self) -> bool:
        """Check if the run is currently active."""
        return self.status == 'running'

    def stop(self) -> None:
        """Stop the experiment scope.

        A running run becomes completed, while a failed run keeps its status.
        Stopping a run that is already completed or that was never started
        only issues a warning.
        """
        if self.status == 'running':  # failed is left as is
            self.status = 'completed'
        elif self.status == 'completed':
            warnings.warn(exceptions.RunAlreadyCompletedWarning(), stacklevel=1)
            return

        if self.status == 'created':
            warnings.warn(exceptions.RunNotStartedWarning(), stacklevel=1)
            return

        if self.record:
            self._update_registry()

        if self._finalizer is not None:
            # the cleanup is performed explicitly below
            self._finalizer.detach()
            self._finalizer = None

        self._stop_experiment(self.experiment, self._id)
        return

    def start(self: Self) -> None:
        """Start the experiment scope.

        Starting a run that is already running only issues a warning.

        Raises:
            NestedScopeError: if an experiment scope is already active. In
                this case the run and the registry are left untouched.
        """
        if self.status == 'running':
            warnings.warn(exceptions.RunAlreadyRunningWarning(), stacklevel=1)
            return

        # Claim the global scope before touching any other state: if another
        # scope is active this raises, and neither the status, the registry,
        # the active run of the experiment nor the finalizer may change,
        # otherwise the failed start would corrupt the scope that is active.
        Experiment.set_current(self._experiment)
        previous_status = self.status
        self.status = 'running'
        try:
            if self.record:
                self._update_registry()
        except BaseException:
            # release the scope, or no other run could start after this
            self.status = previous_status
            Experiment._clear_current()
            raise

        self._finalizer = weakref.finalize(
            self, self._stop_experiment, self._experiment, self._id
        )
        self._experiment._active_run = self
        if self._is_main_process:
            log_events.Event.set_auto_publish(self._experiment.trackers.publish)
        else:  # no tracking in secondary processes
            log_events.Event.set_auto_publish(lambda _: None)

        log_events.StartExperimentEvent(
            self._experiment.config,
            self._experiment.name,
            self.created_at,
            self._id,
            self.resumed,
            self._experiment.par_dir,
            self._experiment.tags,
        )
        return

    def _get_run_id(self, run_id: str | None) -> str:
        """Generate a run ID, appending PID if in a worker process."""
        final_id = run_id or self.created_at_str
        if not self._is_distributed:  # keep the same ID for distributed runs
            if multiprocessing.current_process().name != 'MainProcess':
                final_id = f'{final_id}_{multiprocessing.current_process().pid}'

        return final_id

    def _update_registry(self) -> None:
        """Update the run status in the experiment's registry."""
        self.experiment._registry.update_run_status(self.id, self.status)
        return

    @staticmethod
    def _stop_experiment(experiment: Experiment[_T_co], run_id: str) -> None:
        """Cleanup without holding reference to a Run instance."""
        log_events.StopExperimentEvent(experiment.name, run_id)
        log_events.Event.set_auto_publish(None)
        experiment._active_run = None
        Experiment._clear_current()
        gc.collect()
        return

    @override
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(id={self.id}, status={self.status})'
def _validate_chars(name: str) -> None:
if len(name) > 255:
msg = f'Name is too long (max 255 chars): {len(name)}'
raise ValueError(msg)
not_allowed_chars = set(r'\/:*?"<>|')
if invalid_chars := set(name) & not_allowed_chars:
msg = f'Name contains invalid character(s): {invalid_chars!r}'
raise ValueError(msg)