Source code for drytorch.lib.runners

"""Module containing classes that run a model."""

import abc
import copy
import sys
import warnings

from collections.abc import Iterator, Mapping
from typing import (
    Any,
    ClassVar,
    Final,
    Generic,
    Protocol,
    TypeVar,
    runtime_checkable,
)

import torch

from torch import distributed as dist
from torch.utils import data
from typing_extensions import override

from drytorch.core import exceptions, log_events, registering
from drytorch.core import protocols as p
from drytorch.lib import loading, objectives
from drytorch.utils import apply_ops, repr_utils


__all__ = [
    'ModelCaller',
    'ModelRunner',
    'ModelRunnerWithLogs',
    'ModelRunnerWithObjective',
]


Input = TypeVar('Input', bound=p.InputType)
Target = TypeVar('Target', bound=p.TargetType)
Output = TypeVar('Output', bound=p.OutputType)
_Objective_co = TypeVar(
    '_Objective_co', bound=p.LossProtocol[Any, Any], covariant=True
)


@runtime_checkable
class SupportsSync(Protocol):
    """Protocol for objects that support syncing."""

    def sync(self) -> None:
        """Sync across all processes."""


[docs] class ModelCaller( repr_utils.CreatedAtMixin, Generic[Input, Output], metaclass=abc.ABCMeta ): """Base class that calls a model. Attributes: model: the wrapped model. """ _name = repr_utils.DefaultName() model: p.ModelProtocol[Input, Output] def __init__( self, model: p.ModelProtocol[Input, Output], name: str = '' ) -> None: """Initialize. Args: model: the wrapped model. name: the name for the object for logging purposes. Defaults to class name plus eventual counter. """ super().__init__() self.model: Final = model self._name = name return @property def name(self) -> str: """The name of the model.""" return self._name
[docs] @abc.abstractmethod def __call__(self) -> None: """Document itself when the model is first called.""" registering.register_actor(self, self.model) return
@override def __repr__(self) -> str: return f'{self.name}({self.model.name})'
[docs] class ModelRunner(ModelCaller[Input, Output], Generic[Input, Target, Output]): """Run a model on a dataset. Attributes: model: the model to run. loader: the loader providing inputs and targets in batches. outputs_list: list of optionally stored outputs. """ max_stored_output: ClassVar[int] = sys.maxsize loader: p.LoaderProtocol[tuple[Input, Target]] outputs_list: list[Output] _cached_metrics: Mapping[str, float] _is_distributed: bool _world_size: int _max_stored_process: int def __init__( self, model: p.ModelProtocol[Input, Output], name: str = '', *, loader: p.LoaderProtocol[tuple[Input, Target]], ) -> None: """Initialize. Args: model: the model to run. name: the name for the object for logging purposes. Defaults to class name plus eventual counter. loader: provides inputs and targets in batches. """ super().__init__(model, name) self.loader: Final = loader self.outputs_list: Final = list[Output]() self._cached_metrics = {} self._is_distributed = dist.is_available() and dist.is_initialized() self._world_size = dist.get_world_size() if self._is_distributed else 1 self._max_stored_process = self.max_stored_output // self._world_size return
[docs] def __call__(self, store_outputs: bool = False) -> None: """Single pass on the dataset. Args: store_outputs: whether to store model outputs. Defaults to False. """ super().__call__() self._run_epoch(store_outputs) return
@property def computed_metrics(self) -> Mapping[str, float]: """Retrieve cached metrics.""" return self._cached_metrics def _check_divisibility(self, n_samples: int) -> None: if torch.is_inference_mode_enabled() and self._is_distributed: if n_samples % self._world_size: warnings.warn( exceptions.DistributedDatasetNotDivisibleWarning( self.name, n_samples, self._world_size ), stacklevel=2, ) return def _compute_metrics(self) -> Mapping[str, float]: return {} def _get_batches(self) -> Iterator[tuple[Input, Target]]: if self._is_distributed: if isinstance(self.loader.sampler, data.DistributedSampler): self.loader.sampler.set_epoch(self.model.epoch) return ( apply_ops.apply_to(batch, self.model.device) for batch in self.loader ) def _run_backward(self, outputs: Output, targets: Target) -> None: _not_used = outputs, targets return def _run_batch( self, batch: tuple[Input, Target], ) -> Output: inputs, targets = batch outputs = self._run_forward(inputs) self._run_backward(outputs, targets) return outputs def _run_epoch(self, store_outputs: bool): if self._is_distributed: if not hasattr(self.model.module, 'module'): warnings.warn( exceptions.ModuleNotDistributedWarning(), stacklevel=2 ) self.outputs_list.clear() n_samples = loading.validate_dataset_length(self.loader.dataset) self._check_divisibility(n_samples) n_processes = dist.get_world_size() if self._is_distributed else 1 pbar = log_events.IterateBatchEvent( self.name, self.loader.batch_size, len(self.loader), n_samples ) for batch in self._get_batches(): outputs = self._run_batch(batch) metrics = self._compute_metrics() pbar.update(metrics, n_processes) if store_outputs: self._store(outputs) if self._is_distributed: self._sync() pbar.update(self._compute_metrics(), 0) # trigger sync self._gather_stored_outputs() return def _run_forward(self, inputs: Input) -> Output: return self.model(inputs) def _store(self, outputs: Output) -> None: try: outputs = apply_ops.apply_cpu_detach(outputs) except ( exceptions.FuncNotApplicableError, exceptions.NamedTupleOnlyError, ) as err: warnings.warn( exceptions.CannotStoreOutputWarning(err), stacklevel=3 ) else: if len(self.outputs_list) < self._max_stored_process: self.outputs_list.append(outputs) return def _sync(self) -> None: """Synchronize objective across processes.""" return None def _gather_stored_outputs(self) -> None: """Gather outputs from all processes to rank 0.""" if not self.outputs_list: return rank = dist.get_rank() try: if rank == 0: dist_outputs: list[list[Output]] = [[]] * self._world_size dist.gather_object(self.outputs_list, dist_outputs, dst=0) self.outputs_list.clear() for gathered_tuple in zip(*dist_outputs, strict=True): self.outputs_list.extend(gathered_tuple) else: dist.gather_object(self.outputs_list, None, dst=0) self.outputs_list.clear() except Exception as err: warnings.warn( exceptions.DistributedStorageWarning(err), stacklevel=2 ) return
[docs] class ModelRunnerWithObjective( ModelRunner[Input, Target, Output], p.MonitorProtocol, Generic[Input, Target, Output, _Objective_co], ): """Run a model on a dataset, calculating the value of an objective function. Attributes: model: the model containing the weights to evaluate. loader: provides inputs and targets in batches. objective: processes the model outputs and targets. outputs_list: list of optionally stored outputs. """ objective: _Objective_co _checked_metric_device: bool def __init__( self, model: p.ModelProtocol[Input, Output], name: str = '', *, loader: p.LoaderProtocol[tuple[Input, Target]], objective: _Objective_co, ) -> None: """Initialize. Args: model: the model containing the weights to evaluate. name: the name for the object for logging purposes. Defaults to class name plus eventual counter. loader: provides inputs and targets in batches. objective: processes the model outputs and targets. """ super().__init__(model, loader=loader, name=name) self.objective: Final = copy.deepcopy(objective) if self._is_distributed: if getattr(self.objective, 'sync_on_compute', False): issue = 'sync_on_compute=True will cause overhead' recommend = 'set sync_on_compute to False' warnings.warn( exceptions.ObjectiveSyncWarning(issue, recommend), stacklevel=2, ) if getattr(self.objective, 'dist_sync_on_step', False): issue = 'dist_sync_on_step=True will cause overhead' recommend = 'set dist_sync_on_step to False' warnings.warn( exceptions.ObjectiveSyncWarning(issue, recommend), stacklevel=2, ) self.objective.reset() self._checked_metric_device = False return def _compute_metrics(self) -> Mapping[str, float]: if self._is_distributed and not self._checked_metric_device: try: objectives.check_device(self.objective, self.model.device) except exceptions.DeviceMismatchError as err: raise exceptions.ModelDeviceMismatchError() from err self._checked_metric_device = True self._cached_metrics = objectives.compute_metrics(self.objective) return self._cached_metrics @override def _run_epoch(self, store_outputs: bool): self.objective.reset() # reset before epoch to keep last metrics super()._run_epoch(store_outputs) return @override def _run_backward(self, outputs: Output, targets: Target) -> None: self.objective.update(outputs, targets) super()._run_backward(outputs, targets) return @override def _sync(self) -> None: if isinstance(self.objective, SupportsSync): self.objective.sync() else: issue = 'metrics not synchronized (averaged) across processes' recommend = "override Runner's `_sync` method" warnings.warn( exceptions.ObjectiveSyncWarning(issue, recommend), stacklevel=2, )
[docs] class ModelRunnerWithLogs( ModelRunnerWithObjective[Input, Target, Output, _Objective_co] ): """Run a model on a dataset and log the value of an objective function. Attributes: model: the model containing the weights to evaluate. loader: provides inputs and targets in batches. objective: processes the model outputs and targets. outputs_list: list of optionally stored outputs. """ def _run_epoch(self, store_outputs: bool): super()._run_epoch(store_outputs) log_events.MetricEvent( model_name=self.model.name, source_name=self.name, epoch=self.model.epoch, metrics=self._compute_metrics(), ) return