Source code for drytorch.core.protocols

"""Module containing internal protocols."""

from __future__ import annotations

import abc

from collections.abc import Iterable, Iterator, Mapping, MutableSequence
from typing import (
    Any,
    NamedTuple,
    Protocol,
    TypeAlias,
    TypeVar,
    Union,
    runtime_checkable,
)

import torch

from torch.utils import data


__all__ = [
    'CheckpointProtocol',
    'GradientOpProtocol',
    'InputType',
    'LearningProtocol',
    'LoaderProtocol',
    'LossProtocol',
    'ModelProtocol',
    'MonitorProtocol',
    'ObjectiveProtocol',
    'OutputType',
    'SchedulerProtocol',
    'TargetType',
    'Tensors',
    'TrainerProtocol',
]

# pyright: reportReturnType=false

_T = TypeVar('_T')
Tensors: TypeAlias = torch.Tensor | MutableSequence[torch.Tensor]
InputType: TypeAlias = Union[Tensors, NamedTuple]
OutputType: TypeAlias = Any
TargetType: TypeAlias = Union[Tensors, NamedTuple]

_Data_co = TypeVar(
    '_Data_co', bound=tuple[InputType, TargetType], covariant=True
)
_Output_co = TypeVar('_Output_co', bound=OutputType, covariant=True)

_Input_contra = TypeVar('_Input_contra', bound=InputType, contravariant=True)
_Target_contra = TypeVar('_Target_contra', bound=TargetType, contravariant=True)
_Output_contra = TypeVar('_Output_contra', bound=OutputType, contravariant=True)

_Input = TypeVar('_Input', bound=InputType)
_Target = TypeVar('_Target', bound=TargetType)
_Output = TypeVar('_Output', bound=OutputType)


[docs] class LoaderProtocol(Protocol[_Data_co]): """Protocol loading and batching a dataset. Attributes: batch_size: the batch size. dataset: the dataset to load. sampler: the sampler used to select the samples. """ batch_size: int | None dataset: data.Dataset[Any] sampler: torch.utils.data.Sampler[Any] | Iterable[Any]
[docs] def __iter__(self) -> Iterator[_Data_co]: """Return an iterator over the dataset in batches."""
[docs] def __len__(self) -> int: """Return the number of batches in the dataset."""
[docs] @runtime_checkable class ModelProtocol(Protocol[_Input_contra, _Output_co]): """Protocol for a wrapper around a torch module. Attributes: epoch: the number of epochs the model has been trained so far. checkpoint: the object responsible for saving and loading the model. mixed_precision: whether to use mixed precision computing. """ epoch: int checkpoint: CheckpointProtocol mixed_precision: bool
[docs] @abc.abstractmethod def __call__(self, /, inputs: _Input_contra) -> _Output_co: """Call the module forward method."""
@property def device(self) -> torch.device: """The device where the weights are stored.""" @property def module(self) -> torch.nn.Module: """The module wrapped by the class.""" @property def name(self) -> str: """The name of the model."""
[docs] @abc.abstractmethod def increment_epoch(self) -> None: """Increment the epoch by 1."""
[docs] @abc.abstractmethod def post_batch_update(self) -> None: """Update the model after processing a batch of data."""
[docs] @abc.abstractmethod def post_epoch_update(self) -> None: """Update the model after processing an epoch of data."""
[docs] class CheckpointProtocol(Protocol): """Protocol that stores and loads weight for a ModelProtocol class."""
[docs] def bind_model(self, model: ModelProtocol[Any, Any]) -> None: """Bind the model to manage."""
[docs] def bind_module(self, name: str, module: torch.nn.Module) -> None: """Bind a module connected to the model."""
[docs] def bind_optimizer(self, optimizer: torch.optim.Optimizer) -> None: """Bind the optimizer connected to the model."""
[docs] def save(self) -> None: """Save the model and optimizer state dictionaries."""
[docs] def load(self, epoch: int = -1) -> None: """Load the model and optimizer state dictionaries."""
[docs] class SchedulerProtocol(Protocol): """Protocol of a scheduler for the learning rate."""
[docs] def __call__(self, base_lr: float, epoch: int) -> float: """Modify the learning rate according to a schedule. Args: base_lr: initial learning rate. epoch: the current epoch. Returns: The scheduled value for the learning rate. """
[docs] class GradientOpProtocol(Protocol): """Abstract base class for gradient operations."""
[docs] @abc.abstractmethod def __call__(self, params: Iterable[torch.nn.Parameter]) -> None: """Apply the gradient operation to the given parameters."""
[docs] class LearningProtocol(Protocol): """Protocol with specifications for the learning algorithm. Attributes: optimizer_cls: the optimizer class to bind to the module. base_lr: initial learning rates for named parameters or global value. optimizer_defaults: optional arguments for the optimizer. scheduler: modifies the learning rate given the current epoch. """ optimizer_cls: type[torch.optim.Optimizer] base_lr: float | dict[str, float] scheduler: SchedulerProtocol optimizer_defaults: dict[str, Any] gradient_op: GradientOpProtocol
[docs] @runtime_checkable class ObjectiveProtocol(Protocol[_Output_contra, _Target_contra]): """Protocol that calculates and returns metrics."""
[docs] @abc.abstractmethod def update( self, outputs: _Output_contra, targets: _Target_contra, / ) -> Any: """Compute the metrics only. Args: outputs: model outputs. targets: ground truth. """
[docs] @abc.abstractmethod def compute(self) -> Mapping[str, torch.Tensor] | torch.Tensor | None: """Return a mapping from the metric names to the calculated values."""
[docs] @abc.abstractmethod def reset(self) -> Any: """Reset cached values."""
[docs] @runtime_checkable class LossProtocol(ObjectiveProtocol[_Output_contra, _Target_contra], Protocol): """Protocol that calculates and returns metrics and the loss."""
[docs] def forward( self, outputs: _Output_contra, targets: _Target_contra, / ) -> torch.Tensor: """Process the outputs and targets and returns the loss. Args: outputs: model outputs. targets: ground truth. Returns: The computed loss. """
[docs] class MonitorProtocol(Protocol): """Protocol for a class that validates a model. Attributes: model: the model to evaluate. """ model: ModelProtocol[Any, Any] @property def name(self) -> str: """The name of the model.""" @property def computed_metrics(self) -> Mapping[str, float]: """Computed metric values."""
[docs] @runtime_checkable class TrainerProtocol( MonitorProtocol, Protocol[_Input, _Target, _Output], ): """Protocol for a class that train a model. Attributes: model: the model to train. learning_schema: contains optimizer settings and scheduling. objective: determines the optimization's criterion. validation: class that validates the model, """ model: ModelProtocol[_Input, _Output] learning_schema: LearningProtocol objective: LossProtocol[_Output, _Target] validation: MonitorProtocol | None @property def terminated(self) -> bool: """If true, this trainer should not be used for training anymore."""
[docs] def train(self, n_epochs: int) -> None: """Train the module for the specified number of epochs. Args: n_epochs: the number of epochs for which train the module. """
[docs] def terminate_training(self, reason: str) -> None: """Prevent the trainer from continue the training."""
[docs] def save_checkpoint(self) -> None: """Save model and optimizer state in a checkpoint."""
[docs] def load_checkpoint(self, epoch: int = -1) -> None: """Load model and optimizer state from a checkpoint."""
[docs] def update_learning_rate( self, base_lr: float | None, scheduler: SchedulerProtocol | None, ) -> None: """Update the learning rate(s). It updates the learning rates for each parameter's group in the optimizer based on input learning rate and scheduler. Args: base_lr: the initial learning rate. scheduler: scheduler for the learning rate. """