Source code for drytorch.lib.models

"""Module containing classes for wrapping a torch module and its optimizer."""

from __future__ import annotations

import abc
import sys

from collections.abc import Callable
from typing import ClassVar, Final, Protocol, TypeVar

import torch

from torch import distributed as dist
from torch.nn import parallel
from typing_extensions import override

from drytorch.core import protocols as p
from drytorch.core import registering
from drytorch.lib import checkpoints
from drytorch.utils import repr_utils


# Names exported as the public API of this module.
__all__ = [
    'AveragedModel',
    'EMAModel',
    'Model',
    'SWAModel',
]

# Type variables for the input and output of a module: the input is
# contravariant and bound to p.InputType, the output is covariant and bound
# to p.OutputType.
Input = TypeVar('Input', bound=p.InputType, contravariant=True)
Output = TypeVar('Output', bound=p.OutputType, covariant=True)
# Short alias for torch.Tensor used by the annotations below.
Tensor = torch.Tensor

# A tuple or a list of tensors.
_ParamList = tuple[Tensor, ...] | list[Tensor]
# Callable receiving two tensor collections and a tensor or an int, returning
# None. It is the return type of AveragedModel._get_multi_avg_fn, whose result
# is passed as ``multi_avg_fn`` to torch.optim.swa_utils.AveragedModel.
_MultiAvgFn = Callable[[_ParamList, _ParamList, Tensor | int], None]


class ModuleProtocol(Protocol[Input, Output]):
    """Structural type for a PyTorch module with an annotated forward pass.

    The protocol is generic in the type of the inputs accepted by ``forward``
    and in the type of the outputs it returns.
    """

    @abc.abstractmethod
    def forward(self, inputs: Input, /) -> Output:
        """Run the network on the inputs and return its outputs.

        Args:
            inputs: the positional-only argument fed to the network.

        Returns:
            The outputs computed by the network.
        """
        ...


class Model(repr_utils.CreatedAtMixin, p.ModelProtocol[Input, Output]):
    """Wrapper for a torch.nn.Module class with extra information.

    The wrapper moves the module to a device, optionally compiles it and
    wraps it for distributed data-parallel execution, binds a checkpoint
    manager to itself, and registers itself in the registry.

    Attributes:
        exec_module: Pytorch module used for execution.
        epoch: the number of epochs the model has been trained so far.
        mixed_precision: whether to use mixed precision computing.
        checkpoint: checkpoint manager.
    """

    _name = repr_utils.DefaultName()
    exec_module: torch.nn.Module
    epoch: int
    mixed_precision: bool
    checkpoint: p.CheckpointProtocol
    _device: torch.device
    _should_compile: bool
    _should_dist: bool
    _registered: bool

    def __init__(  # type: ignore
        self,
        module: ModuleProtocol[Input, Output],
        name: str = '',
        device: torch.device | None = None,
        checkpoint: p.CheckpointProtocol | None = None,
        mixed_precision: bool = False,
        should_compile: bool = True,
        should_distribute: bool = True,
    ) -> None:
        """Initialize.

        Option should_distribute assumes that there is a single accelerator
        for each process and that the device for the process is already set.

        Args:
            module: Pytorch module with type annotations.
            name: the name of the model. Default uses the class name.
            device: the device where to store the weights of the module.
                Default uses the current accelerator if available, cpu
                otherwise.
            checkpoint: class that saves the state and optionally the
                optimizer. Default uses a checkpoints.LocalCheckpoint.
            mixed_precision: whether to use mixed precision computing.
            should_compile: compile the module at instantiation
                (Python < 3.14).
            should_distribute: wrap the module for data-distributed settings.

        Raises:
            TypeError: if the module is not a torch.nn.Module instance.
        """
        super().__init__()
        self._device = self._default_device() if device is None else device
        self._should_compile = should_compile
        self._should_dist = should_distribute
        self.mixed_precision: Final = mixed_precision
        torch_module = self._validate_module(module)
        self.exec_module: Final = self.prepare_module(torch_module)
        self._name = name
        self.epoch = 0
        if checkpoint is None:
            checkpoint = checkpoints.LocalCheckpoint()

        self.checkpoint = checkpoint
        self.checkpoint.bind_model(self)
        # the flag must exist before registering, see unregister()
        self._registered = False
        self.register()

    def __call__(self, inputs: Input) -> Output:
        """Execute forward pass.

        The pass runs inside a torch.autocast context for the model device,
        enabled only when mixed_precision is set.
        """
        with torch.autocast(
            device_type=self.device.type, enabled=self.mixed_precision
        ):
            return self.exec_module(inputs)

    def __del__(self):
        """Unregister from the registry when deleted/garbage-collected."""
        try:
            self.unregister()
        except AttributeError:  # may happen during instantiation
            pass

    @property
    def device(self) -> torch.device:
        """The device where the weights are stored."""
        return self._device

    @property
    def module(self) -> torch.nn.Module:
        """The module wrapped by the class."""
        return self._unwrap_module()

    @property
    def name(self) -> str:
        """The name of the model."""
        return self._name

    def prepare_module(self, module: torch.nn.Module) -> torch.nn.Module:
        """Compile and distribute the module.

        Args:
            module: the module to prepare.

        Returns:
            The module moved to the model device, compiled when requested,
            and wrapped in DistributedDataParallel when requested and a
            process group is initialized.
        """
        module = module.to(self._device)
        # TODO: remove flag when torch.compile is supported on Python 3.14
        if self._should_compile and sys.version_info < (3, 14):
            # torch.compile(module) returns the compiled wrapper and leaves
            # its argument untouched, so discarding the result compiles
            # nothing. Module.compile() compiles in place instead, which
            # keeps exec_module the same object as the wrapped module.
            module.compile()

        if dist.is_available() and dist.is_initialized() and self._should_dist:
            if self._device.type == 'cuda':
                module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)

            # device_ids=None is the DistributedDataParallel default
            index = self.device.index
            device_ids = None if index is None else [index]
            module = parallel.DistributedDataParallel(
                module, device_ids=device_ids
            )

        return module

    def increment_epoch(self) -> None:
        """Increment the epoch by 1."""
        self.epoch += 1

    def load_state(self, epoch: int = -1) -> None:
        """Load the weights and epoch of the model.

        Args:
            epoch: the epoch forwarded to the checkpoint manager.
        """
        self.checkpoint.load(epoch=epoch)

    def register(self) -> None:
        """Register to the registry."""
        registering.register_model(self)
        self._registered = True

    def save_state(self) -> None:
        """Save the weights and epoch of the model."""
        self.checkpoint.save()

    def unregister(self) -> None:
        """Unregister from the registry if currently registered."""
        if self._registered:
            registering.unregister_model(self)

        self._registered = False

    def _unwrap_module(self) -> torch.nn.Module:
        """Return the module without wrapping."""
        if isinstance(self.exec_module, parallel.DistributedDataParallel):
            return self.exec_module.module

        return self.exec_module

    def post_batch_update(self) -> None:
        """Update the model after processing a batch of data."""
        return

    def post_epoch_update(self) -> None:
        """Update the model after processing an epoch of data."""
        return

    @staticmethod
    def _default_device() -> torch.device:
        """Return the current accelerator device, or cpu if there is none."""
        device = torch.accelerator.current_accelerator()
        if device is not None:
            index = torch.accelerator.current_device_index()
            return torch.device(device.type, index)

        return torch.device('cpu')

    @staticmethod
    def _validate_module(
        torch_model: ModuleProtocol[Input, Output],
    ) -> torch.nn.Module:
        """Return the argument after checking it is a torch.nn.Module.

        Raises:
            TypeError: if the argument is not a torch.nn.Module instance.
        """
        if not isinstance(torch_model, torch.nn.Module):
            raise TypeError('torch_module must be a torch.nn.Module subclass')

        return torch_model
class AveragedModel(Model[Input, Output], abc.ABC):
    """Pair a torch.nn.Module with a torch.optim.swa_utils.AveragedModel.

    Calls made while inference mode is enabled are answered by the averaged
    module; all other calls go through the regular forward pass of Model.
    Subclasses choose the averaging strategy via _get_multi_avg_fn.

    Attributes:
        exec_averaged_module: the averaged module.
    """

    average_name: ClassVar[str] = 'averaged_model'
    exec_averaged_module: torch.optim.swa_utils.AveragedModel

    def __init__(
        self,
        torch_module: ModuleProtocol[Input, Output],
        /,
        name: str = '',
        device: torch.device | None = None,
        checkpoint: p.CheckpointProtocol | None = None,
        mixed_precision: bool = False,
    ) -> None:
        """Initialize.

        Args:
            torch_module: Pytorch module with type annotations.
            name: the name of the model. Default uses the class name.
            device: the device where to store the weights of the module.
                Default is the one chosen by Model.
            checkpoint: class that saves the state and optionally the
                optimizer.
            mixed_precision: whether to use mixed precision computing.
                Defaults to False.
        """
        super().__init__(
            torch_module,
            name=name,
            device=device,
            checkpoint=checkpoint,
            mixed_precision=mixed_precision,
        )
        self.exec_averaged_module = self._create_averaged_module()
        # bind the wrapping AveragedModel itself, not the module inside it
        self.checkpoint.bind_module(
            self.average_name, self.exec_averaged_module
        )

    def __call__(self, inputs: Input) -> Output:
        """Execute the forward pass.

        The averaged module is used, without mixed precision, when inference
        mode is enabled.
        """
        if not torch.is_inference_mode_enabled():
            return super().__call__(inputs)

        return self.exec_averaged_module(inputs)

    @property
    def averaged_module(self) -> torch.nn.Module:
        """The module held inside the averaged wrapper."""
        return self._unwrap_averaged_module()

    def _create_averaged_module(self) -> torch.optim.swa_utils.AveragedModel:
        """Build the averaged copy in evaluation mode and without gradients."""
        wrapper = torch.optim.swa_utils.AveragedModel(
            self.module,
            device=self.device,
            multi_avg_fn=self._get_multi_avg_fn(),
            use_buffers=True,
        )
        wrapper.eval()
        wrapper.requires_grad_(False)
        return wrapper

    @abc.abstractmethod
    def _get_multi_avg_fn(self) -> _MultiAvgFn | None:
        """Define the averaging function for the model parameters."""

    def _update_parameters(self) -> None:
        """Update the averaged module with the unwrapped module."""
        source = self._unwrap_module()
        self.exec_averaged_module.update_parameters(source)

    def _unwrap_averaged_module(self) -> torch.nn.Module:
        """Return the module stored in the averaged wrapper."""
        return self.exec_averaged_module.module
class SWAModel(AveragedModel[Input, Output]):
    """Averaged model that starts averaging at a given epoch.

    The averaged module answers calls made in inference mode once the epoch
    has reached start_epoch. It is updated after every epoch from that
    point on.

    Attributes:
        exec_averaged_module: the averaged module.
        start_epoch: the epoch at which to start averaging.
    """

    average_name = 'swa_model'
    exec_averaged_module: torch.optim.swa_utils.AveragedModel

    def __init__(
        self,
        torch_module: ModuleProtocol[Input, Output],
        /,
        start_epoch: int,
        name: str = '',
        device: torch.device | None = None,
        checkpoint: p.CheckpointProtocol | None = None,
        mixed_precision: bool = False,
    ) -> None:
        """Initialize.

        Args:
            torch_module: Pytorch module with type annotations.
            start_epoch: the epoch at which to start averaging.
            name: the name of the model. Default uses the class name.
            device: the device where to store the weights of the module.
                Default is the one chosen by Model.
            checkpoint: class that saves the state and optionally the
                optimizer.
            mixed_precision: whether to use mixed precision computing.
                Defaults to False.
        """
        self.start_epoch: Final = start_epoch
        super().__init__(
            torch_module,
            name=name,
            device=device,
            checkpoint=checkpoint,
            mixed_precision=mixed_precision,
        )

    def __call__(self, inputs: Input) -> Output:
        """Execute the forward pass.

        The averaged module is used, without mixed precision, when inference
        mode is enabled and averaging has started.
        """
        if torch.is_inference_mode_enabled():
            if self.epoch >= self.start_epoch:
                return self.averaged_module(inputs)

        # skip AveragedModel.__call__ and continue with the next class
        return super(AveragedModel, self).__call__(inputs)

    @override
    def post_epoch_update(self) -> None:
        """Update the averaged parameters once averaging has started."""
        averaging_started = self.epoch >= self.start_epoch
        if averaging_started:
            self._update_parameters()

    @override
    def _get_multi_avg_fn(self) -> None:
        """Return None: no custom averaging function is supplied."""
        return None
class EMAModel(AveragedModel[Input, Output]):
    """Averaged model using an exponential moving average of the parameters.

    The averaged module answers calls made in inference mode and is updated
    after every batch.

    Attributes:
        exec_averaged_module: the averaged module.
        decay: the exponential decay rate for the moving average.
    """

    average_name = 'ema_model'
    exec_averaged_module: torch.optim.swa_utils.AveragedModel
    decay: float

    def __init__(
        self,
        torch_module: ModuleProtocol[Input, Output],
        /,
        name: str = '',
        device: torch.device | None = None,
        checkpoint: p.CheckpointProtocol | None = None,
        mixed_precision: bool = False,
        decay: float = 0.999,
    ) -> None:
        """Initialize.

        Args:
            torch_module: Pytorch module with type annotations.
            name: the name of the model. Default uses the class name.
            device: the device where to store the weights of the module.
                Default is the one chosen by Model.
            checkpoint: class that saves the state and optionally the
                optimizer.
            mixed_precision: whether to use mixed precision computing.
                Defaults to False.
            decay: the exponential decay rate for the moving average.
        """
        # set before the parent initializer, which builds the averaged
        # module and therefore calls _get_multi_avg_fn
        self.decay: Final = decay
        super().__init__(
            torch_module,
            name=name,
            device=device,
            checkpoint=checkpoint,
            mixed_precision=mixed_precision,
        )

    @override
    def _get_multi_avg_fn(self) -> _MultiAvgFn:
        """Return the torch EMA averaging function for the stored decay."""
        ema_fn = torch.optim.swa_utils.get_ema_multi_avg_fn(decay=self.decay)
        return ema_fn

    @override
    def post_batch_update(self) -> None:
        """Update the averaged parameters after each batch."""
        self._update_parameters()