"""Module containing classes for wrapping a torch module and its optimizer."""
from __future__ import annotations
import abc
from collections.abc import Callable
from typing import ClassVar, Final, Protocol, TypeVar
import torch
from torch import distributed as dist
from torch.nn import parallel
from typing_extensions import override
from drytorch.core import protocols as p
from drytorch.core import registering
from drytorch.lib import checkpoints
from drytorch.utils import repr_utils
# Public names exported by this module.
__all__ = [
'AveragedModel',
'EMAModel',
'Model',
'SWAModel',
]
# Type of the input accepted by a wrapped module (contravariant, bounded by
# the project's InputType protocol).
Input = TypeVar('Input', bound=p.InputType, contravariant=True)
# Type of the output returned by a wrapped module (covariant, bounded by the
# project's OutputType protocol).
Output = TypeVar('Output', bound=p.OutputType, covariant=True)
# Short alias for torch.Tensor, used in the aliases below.
Tensor = torch.Tensor
# A tuple or list of tensors.
_ParamList = tuple[Tensor, ...] | list[Tensor]
# Signature of the function returned by ``_get_multi_avg_fn`` and passed as
# ``multi_avg_fn`` to torch.optim.swa_utils.AveragedModel: it takes two
# parameter lists and a Tensor or int, and returns None.
# NOTE(review): presumably (averaged params, current params, number of models
# averaged so far), updated in place - confirm against the PyTorch docs.
_MultiAvgFn = Callable[[_ParamList, _ParamList, Tensor | int], None]
class ModuleProtocol(Protocol[Input, Output]):
    """Structural type for a PyTorch module with an annotated forward pass.

    The type parameters record the input the module accepts and the output
    it returns, so that wrappers built around it can be generic.
    """

    @abc.abstractmethod
    def forward(self, inputs: Input, /) -> Output:
        """Run the network on the inputs and return its output."""
        ...
class Model(repr_utils.CreatedAtMixin, p.ModelProtocol[Input, Output]):
    """Wrapper for a torch.nn.Module class with extra information.

    On construction the wrapper moves the module to the target device,
    optionally compiles it and wraps it for distributed data-parallel
    execution, binds itself to a checkpoint manager, and registers itself.

    Attributes:
        exec_module: Pytorch module used for execution.
        epoch: the number of epochs the model has been trained so far.
        mixed_precision: whether to use mixed precision computing.
        checkpoint: checkpoint manager.
    """

    _name = repr_utils.DefaultName()
    exec_module: torch.nn.Module
    epoch: int
    mixed_precision: bool
    checkpoint: p.CheckpointProtocol
    _device: torch.device
    _should_compile: bool
    _should_dist: bool
    _registered: bool

    def __init__(  # type: ignore
        self,
        module: ModuleProtocol[Input, Output],
        name: str = '',
        device: torch.device | None = None,
        checkpoint: p.CheckpointProtocol | None = None,
        mixed_precision: bool = False,
        should_compile: bool = True,
        should_distribute: bool = True,
    ) -> None:
        """Initialize.

        Option should_distribute assumes that there is a single accelerator
        for each process and that the device for the process is already set.

        Args:
            module: Pytorch module with type annotations.
            name: the name of the model. Default uses the class name.
            device: the device where to store the weights of the module.
                Default uses the accelerator if available, cpu otherwise.
            checkpoint: class that saves the state and optionally the
                optimizer. Default creates a LocalCheckpoint.
            mixed_precision: whether to use mixed precision computing.
            should_compile: compile the module (Python < 3.14).
            should_distribute: wrap the module for data-distributed settings.

        Raises:
            TypeError: if module is not a torch.nn.Module instance.
        """
        super().__init__()
        self._device = self._default_device() if device is None else device
        self._should_compile = should_compile
        self._should_dist = should_distribute
        self.mixed_precision: Final = mixed_precision
        torch_module = self._validate_module(module)
        self.exec_module: Final = self.prepare_module(torch_module)
        self._name = name
        self.epoch = 0
        if checkpoint is None:
            checkpoint = checkpoints.LocalCheckpoint()
        self.checkpoint = checkpoint
        self.checkpoint.bind_model(self)
        # Set the flag before registering so that unregister() is always safe.
        self._registered = False
        self.register()
        return

    def __call__(self, inputs: Input) -> Output:
        """Execute forward pass, under autocast if mixed precision is on."""
        with torch.autocast(
            device_type=self.device.type, enabled=self.mixed_precision
        ):
            return self.exec_module(inputs)

    def __del__(self):
        """Unregister from the registry when deleted/garbage-collected."""
        try:
            self.unregister()
        except AttributeError:  # may happen during instantiation
            pass
        return

    @property
    def device(self) -> torch.device:
        """The device where the weights are stored."""
        return self._device

    @property
    def module(self) -> torch.nn.Module:
        """The module wrapped by the class, without the distributed wrapper."""
        return self._unwrap_module()

    @property
    def name(self) -> str:
        """The name of the model."""
        return self._name

    def prepare_module(self, module: torch.nn.Module) -> torch.nn.Module:
        """Move the module to the device, then compile and distribute it.

        Args:
            module: the module to prepare.

        Returns:
            The module on the target device. It is wrapped in
            DistributedDataParallel when distribution was requested and a
            process group is available and initialized.
        """
        module = module.to(self._device)
        if self._should_compile:
            # torch.compile(module) returns a new wrapper and leaves its
            # argument untouched, so calling it without using the result
            # compiles nothing. Module.compile() compiles the module in
            # place, which keeps its type (and therefore the unwrapping
            # logic of this class) unchanged.
            module.compile()
        if dist.is_available() and dist.is_initialized() and self._should_dist:
            if self._device.type == 'cuda':
                module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
            if self.device.index is not None:
                module = parallel.DistributedDataParallel(
                    module, device_ids=[self.device.index]
                )
            else:
                module = parallel.DistributedDataParallel(module)
        return module

    def increment_epoch(self) -> None:
        """Increment the epoch by 1."""
        self.epoch += 1

    def load_state(self, epoch: int = -1) -> None:
        """Load the weights and epoch of the model.

        Args:
            epoch: the epoch forwarded to the checkpoint's load method.
                NOTE(review): presumably -1 selects the latest saved epoch;
                confirm against the checkpoint implementation.
        """
        self.checkpoint.load(epoch=epoch)

    def register(self) -> None:
        """Register to the registry."""
        registering.register_model(self)
        self._registered = True
        return

    def save_state(self) -> None:
        """Save the weights and epoch of the model."""
        self.checkpoint.save()

    def unregister(self) -> None:
        """Unregister from the registry if currently registered."""
        if self._registered:
            registering.unregister_model(self)
        self._registered = False
        return

    def _unwrap_module(self) -> torch.nn.Module:
        """Return the module without wrapping."""
        if isinstance(self.exec_module, parallel.DistributedDataParallel):
            return self.exec_module.module
        return self.exec_module

    def post_batch_update(self) -> None:
        """Update the model after processing a batch of data.

        The base implementation does nothing.
        """
        return

    def post_epoch_update(self) -> None:
        """Update the model after processing an epoch of data.

        The base implementation does nothing.
        """
        return

    @staticmethod
    def _default_device() -> torch.device:
        """Return the current accelerator device, or cpu if there is none."""
        device = torch.accelerator.current_accelerator()
        if device is not None:
            index = torch.accelerator.current_device_index()
            return torch.device(device.type, index)
        return torch.device('cpu')

    @staticmethod
    def _validate_module(
        torch_model: ModuleProtocol[Input, Output],
    ) -> torch.nn.Module:
        """Check that the argument is a torch.nn.Module and return it.

        Raises:
            TypeError: if torch_model is not a torch.nn.Module instance.
        """
        if not isinstance(torch_model, torch.nn.Module):
            raise TypeError('torch_module must be a torch.nn.Module subclass')
        return torch_model
class AveragedModel(Model[Input, Output], abc.ABC):
    """Pair a torch.nn.Module with an averaged copy of its weights.

    The averaged copy is a torch.optim.swa_utils.AveragedModel. It serves the
    forward pass whenever torch inference mode is enabled; otherwise the
    trainable module is used. Subclasses select the averaging rule.

    Attributes:
        average_name: name under which the averaged module is bound to the
            checkpoint.
        exec_averaged_module: the averaged module.
    """

    average_name: ClassVar[str] = 'averaged_model'
    exec_averaged_module: torch.optim.swa_utils.AveragedModel

    def __init__(
        self,
        torch_module: ModuleProtocol[Input, Output],
        /,
        name: str = '',
        device: torch.device | None = None,
        checkpoint: p.CheckpointProtocol | None = None,
        mixed_precision: bool = False,
    ) -> None:
        """Initialize.

        Args:
            torch_module: Pytorch module with type annotations.
            name: the name of the model. Default uses the class name.
            device: the device where to store the weights of the module.
                Default uses the accelerator if available, cpu otherwise.
            checkpoint: class that saves the state and optionally the
                optimizer.
            mixed_precision: whether to use mixed precision computing.
                Defaults to False.
        """
        super().__init__(
            torch_module,
            name=name,
            device=device,
            checkpoint=checkpoint,
            mixed_precision=mixed_precision,
        )
        averaged = self._create_averaged_module()
        self.exec_averaged_module = averaged
        # Bind the AveragedModel wrapper itself, not the module inside it.
        self.checkpoint.bind_module(self.average_name, averaged)
        return

    def __call__(self, inputs: Input) -> Output:
        """Execute the forward pass.

        The averaged module is used in inference mode (without autocast);
        the trainable module is used otherwise.
        """
        if not torch.is_inference_mode_enabled():
            return super().__call__(inputs)
        return self.exec_averaged_module(inputs)  # no mixed precision here

    @property
    def averaged_module(self) -> torch.nn.Module:
        """The module held inside the averaging wrapper."""
        return self._unwrap_averaged_module()

    def _create_averaged_module(self) -> torch.optim.swa_utils.AveragedModel:
        """Build the averaged copy: on device, in eval mode, without grads."""
        averaged = torch.optim.swa_utils.AveragedModel(
            self.module,
            device=self.device,
            multi_avg_fn=self._get_multi_avg_fn(),
            use_buffers=True,
        )
        averaged.eval()
        # The averaged weights are never trained directly.
        averaged.requires_grad_(False)
        return averaged

    @abc.abstractmethod
    def _get_multi_avg_fn(self) -> _MultiAvgFn | None:
        """Define the averaging function for the model parameters."""

    def _update_parameters(self) -> None:
        """Fold the current weights of the module into the average."""
        current = self._unwrap_module()
        self.exec_averaged_module.update_parameters(current)
        return

    def _unwrap_averaged_module(self) -> torch.nn.Module:
        """Return the module held inside the averaging wrapper."""
        return self.exec_averaged_module.module
class SWAModel(AveragedModel[Input, Output]):
    """Averaged model that starts averaging at a given epoch.

    The averaged weights are updated after every epoch from start_epoch
    onwards, and are used for the forward pass in inference mode once that
    epoch has been reached.

    Attributes:
        exec_averaged_module: the averaged module.
        start_epoch: the epoch at which to start averaging.
    """

    average_name = 'swa_model'
    exec_averaged_module: torch.optim.swa_utils.AveragedModel

    def __init__(
        self,
        torch_module: ModuleProtocol[Input, Output],
        /,
        start_epoch: int,
        name: str = '',
        device: torch.device | None = None,
        checkpoint: p.CheckpointProtocol | None = None,
        mixed_precision: bool = False,
    ) -> None:
        """Initialize.

        Args:
            torch_module: Pytorch module with type annotations.
            start_epoch: the epoch at which to start averaging.
            name: the name of the model. Default uses the class name.
            device: the device where to store the weights of the module.
                Default uses the accelerator if available, cpu otherwise.
            checkpoint: class that saves the state and optionally the
                optimizer.
            mixed_precision: whether to use mixed precision computing.
                Defaults to False.
        """
        self.start_epoch: Final = start_epoch
        super().__init__(
            torch_module,
            name=name,
            device=device,
            checkpoint=checkpoint,
            mixed_precision=mixed_precision,
        )
        return

    def __call__(self, inputs: Input) -> Output:
        """Execute the forward pass.

        The averaged module is used only in inference mode and once the
        start epoch has been reached; otherwise the call is handled as in
        the plain model wrapper.
        """
        if (
            not torch.is_inference_mode_enabled()
            or self.epoch < self.start_epoch
        ):
            # Deliberately skip AveragedModel.__call__ in the lookup.
            return super(AveragedModel, self).__call__(inputs)
        return self.averaged_module(inputs)  # no mixed precision here

    @override
    def post_epoch_update(self) -> None:
        """Update the averaged weights once the start epoch is reached."""
        if self.epoch >= self.start_epoch:
            self._update_parameters()
        return

    @override
    def _get_multi_avg_fn(self) -> None:
        """Return None so that the averaging wrapper uses its default rule."""
        return None
class EMAModel(AveragedModel[Input, Output]):
    """Averaged model keeping an exponential moving average of the weights.

    The averaged weights are updated after every batch and are used for the
    forward pass in inference mode.

    Attributes:
        exec_averaged_module: the averaged module.
        decay: the exponential decay rate for the moving average.
    """

    average_name = 'ema_model'
    exec_averaged_module: torch.optim.swa_utils.AveragedModel
    decay: float

    def __init__(
        self,
        torch_module: ModuleProtocol[Input, Output],
        /,
        name: str = '',
        device: torch.device | None = None,
        checkpoint: p.CheckpointProtocol | None = None,
        mixed_precision: bool = False,
        decay: float = 0.999,
    ) -> None:
        """Initialize.

        Args:
            torch_module: Pytorch module with type annotations.
            name: the name of the model. Default uses the class name.
            device: the device where to store the weights of the module.
                Default uses the accelerator if available, cpu otherwise.
            checkpoint: class that saves the state and optionally the
                optimizer.
            mixed_precision: whether to use mixed precision computing.
                Defaults to False.
            decay: the exponential decay rate for the moving average.
        """
        # Must be set first: the base initializer builds the averaged module,
        # which reads the decay through _get_multi_avg_fn.
        self.decay: Final = decay
        super().__init__(
            torch_module,
            name=name,
            device=device,
            checkpoint=checkpoint,
            mixed_precision=mixed_precision,
        )
        return

    @override
    def _get_multi_avg_fn(self) -> _MultiAvgFn:
        """Return the EMA averaging function for the configured decay."""
        swa_utils = torch.optim.swa_utils
        return swa_utils.get_ema_multi_avg_fn(decay=self.decay)

    @override
    def post_batch_update(self) -> None:
        """Update the averaged weights after each batch."""
        self._update_parameters()
        return