drytorch.lib.models

Module containing classes for wrapping a torch module and its optimizer.

Classes

AveragedModel(torch_module, /[, name, ...])

Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.

EMAModel(torch_module, /[, name, device, ...])

Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.

Model(module[, name, device, checkpoint, ...])

Wrapper for a torch.nn.Module class with extra information.

ModuleProtocol(*args, **kwargs)

Protocol for a PyTorch module with type annotations.

SWAModel(torch_module, /, start_epoch[, ...])

Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.
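The ModuleProtocol entry above is described only as a protocol for a PyTorch module with type annotations. As a rough illustration of that idea, the sketch below uses a generic typing.Protocol to tie an input type to an output type. The names CallableModule, Doubler, and run are hypothetical and are not part of drytorch; the real protocol may declare more members than __call__.

```python
from typing import Protocol, TypeVar

# Variance follows the usual rule: inputs are consumed, outputs are produced.
Input = TypeVar('Input', contravariant=True)
Output = TypeVar('Output', covariant=True)


class CallableModule(Protocol[Input, Output]):
    """Hypothetical stand-in for a typed module protocol."""

    def __call__(self, inputs: Input) -> Output:
        ...


class Doubler:
    """Toy 'module' that satisfies the protocol structurally."""

    def __call__(self, inputs: float) -> float:
        return 2.0 * inputs


def run(module: CallableModule[float, float], value: float) -> float:
    # A static type checker can verify that `module` maps float to float.
    return module(value)


result = run(Doubler(), 3.0)
```

Because the check is structural, Doubler does not need to inherit from the protocol; it only has to expose a matching __call__.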

class AveragedModel(torch_module: ModuleProtocol[Input, Output], /, name: str = '', device: device | None = None, checkpoint: CheckpointProtocol | None = None, mixed_precision: bool = False)[source]

Bases: Model[Input, Output], ABC

Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.

Use the averaged model when in inference mode.

exec_averaged_module

the averaged module.

Type:

torch.optim.swa_utils.AveragedModel

Initialize.

Parameters:
  • torch_module (ModuleProtocol[Input, Output]) – PyTorch module with type annotations.

  • name (str) – the name of the model. Defaults to the class name.

  • device (torch.device | None) – the device on which to store the module's weights. Defaults to CUDA when available, CPU otherwise.

  • checkpoint (CheckpointProtocol | None) – checkpoint manager that saves the model state and, optionally, the optimizer state.

  • mixed_precision (bool) – whether to use mixed precision computing. Defaults to False.

__call__(inputs: Input) Output[source]

Execute the forward pass.

Parameters:

inputs (Input)

Return type:

Output

property averaged_module: Module

The averaged module wrapped by the class.

class EMAModel(torch_module: ModuleProtocol[Input, Output], /, name: str = '', device: device | None = None, checkpoint: CheckpointProtocol | None = None, mixed_precision: bool = False, decay: float = 0.999)[source]

Bases: AveragedModel[Input, Output]

Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.

Use the averaged model when in inference mode.

exec_averaged_module

the averaged module.

Type:

torch.optim.swa_utils.AveragedModel

decay

the exponential decay rate for the moving average.

Type:

float

Initialize.

Parameters:
  • torch_module (ModuleProtocol[Input, Output]) – PyTorch module with type annotations.

  • name (str) – the name of the model. Defaults to the class name.

  • device (torch.device | None) – the device on which to store the module's weights. Defaults to CUDA when available, CPU otherwise.

  • checkpoint (CheckpointProtocol | None) – checkpoint manager that saves the model state and, optionally, the optimizer state.

  • mixed_precision (bool) – whether to use mixed precision computing. Defaults to False.

  • decay (float) – the exponential decay rate for the moving average.

post_batch_update() None[source]

Update the model after processing a batch of data.

Return type:

None
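The decay parameter is documented only as the exponential decay rate for the moving average. A minimal sketch of the standard exponential moving average rule, assuming it is applied once per batch to each weight, is given below. The function ema_update and the plain lists standing in for weight tensors are hypothetical; the sketch also ignores how the first average is initialized.

```python
def ema_update(
    averaged: list[float], current: list[float], decay: float
) -> list[float]:
    """Blend the running average with the current weights.

    Each entry becomes decay * averaged + (1 - decay) * current.
    """
    return [decay * a + (1.0 - decay) * c for a, c in zip(averaged, current)]


# Two batches with identical current weights and a deliberately low decay,
# so the drift toward the current weights is easy to follow by hand.
averaged = [0.0, 1.0]
for batch_weights in ([1.0, 1.0], [1.0, 1.0]):
    averaged = ema_update(averaged, batch_weights, decay=0.5)
```

With the default decay of 0.999, each batch moves the average only 0.1% of the way toward the current weights, which is why the averaged weights change slowly and smoothly.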

class Model(module: ModuleProtocol[Input, Output], name: str = '', device: device | None = None, checkpoint: CheckpointProtocol | None = None, mixed_precision: bool = False, should_compile: bool = True, should_distribute: bool = True)[source]

Bases: CreatedAtMixin, ModelProtocol[Input, Output]

Wrapper for a torch.nn.Module class with extra information.

exec_module

PyTorch module used for execution.

Type:

torch.nn.modules.module.Module

epoch

the number of epochs for which the model has been trained so far.

Type:

int

mixed_precision

whether to use mixed precision computing.

Type:

bool

checkpoint

checkpoint manager.

Type:

drytorch.core.protocols.CheckpointProtocol

Initialize.

The should_distribute option assumes that each process has a single accelerator and that the device for the process has already been set.

Parameters:
  • module (ModuleProtocol[Input, Output]) – PyTorch module with type annotations.

  • name (str) – the name of the model. Defaults to the class name.

  • device (torch.device | None) – the device on which to store the module's weights. Defaults to the accelerator if available, CPU otherwise.

  • checkpoint (CheckpointProtocol | None) – checkpoint manager that saves the model state and, optionally, the optimizer state.

  • mixed_precision (bool) – whether to use mixed precision computing.

  • should_compile (bool) – whether to compile the module at instantiation (Python < 3.14 only).

  • should_distribute (bool) – whether to wrap the module for data-distributed settings.

__call__(inputs: Input) Output[source]

Execute the forward pass.

Parameters:

inputs (Input)

Return type:

Output

__del__()[source]

Unregister from the registry when the object is deleted or garbage-collected.

property device: device

The device where the weights are stored.

property module: Module

The module wrapped by the class.

property name: str

The name of the model.

prepare_module(module: Module) Module[source]

Compile and distribute the module.

Parameters:

module (Module)

Return type:

Module

increment_epoch() None[source]

Increment the epoch by 1.

Return type:

None

load_state(epoch=-1) None[source]

Load the weights and epoch of the model.

Return type:

None

register() None[source]

Register the model in the registry.

Return type:

None

save_state() None[source]

Save the weights and epoch of the model.

Return type:

None

unregister() None[source]

Unregister from the registry.

Return type:

None

post_batch_update() None[source]

Update the model after processing a batch of data.

Return type:

None

post_epoch_update() None[source]

Update the model after processing an epoch of data.

Return type:

None

class SWAModel(torch_module: ModuleProtocol[Input, Output], /, start_epoch: int, name: str = '', device: device | None = None, checkpoint: CheckpointProtocol | None = None, mixed_precision: bool = False)[source]

Bases: AveragedModel[Input, Output]

Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.

Use the averaged model when in inference mode.

exec_averaged_module

the averaged module.

Type:

torch.optim.swa_utils.AveragedModel

start_epoch

the epoch at which to start averaging.

Type:

int

Initialize.

Parameters:
  • torch_module (ModuleProtocol[Input, Output]) – PyTorch module with type annotations.

  • start_epoch (int) – the epoch at which to start averaging.

  • name (str) – the name of the model. Defaults to the class name.

  • device (torch.device | None) – the device on which to store the module's weights. Defaults to CUDA when available, CPU otherwise.

  • checkpoint (CheckpointProtocol | None) – checkpoint manager that saves the model state and, optionally, the optimizer state.

  • mixed_precision (bool) – whether to use mixed precision computing. Defaults to False.

__call__(inputs: Input) Output[source]

Execute the forward pass.

Parameters:

inputs (Input)

Return type:

Output

post_epoch_update() None[source]

Update the model after processing an epoch of data.

Return type:

None
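SWAModel is documented as averaging from start_epoch onward, with the update happening after each epoch. A minimal sketch of the usual equal-weight running average is given below, assuming that start_epoch itself is included in the average. The function swa_update, the plain lists standing in for weight tensors, and the per-epoch values are all hypothetical.

```python
def swa_update(
    averaged: list[float], current: list[float], n_averaged: int
) -> list[float]:
    """Fold the current weights into an equal-weight running mean.

    n_averaged is the number of snapshots already included in `averaged`.
    """
    return [a + (c - a) / (n_averaged + 1) for a, c in zip(averaged, current)]


start_epoch = 2  # assumption: averaging includes this epoch
weights_per_epoch = {1: [10.0], 2: [2.0], 3: [4.0], 4: [6.0]}

averaged = None
n_averaged = 0
for epoch, weights in weights_per_epoch.items():
    if epoch < start_epoch:
        continue  # epochs before start_epoch do not contribute
    if averaged is None:
        averaged = list(weights)  # first snapshot initializes the average
    else:
        averaged = swa_update(averaged, weights, n_averaged)
    n_averaged += 1
```

The incremental form avoids storing every snapshot: only the current average and the count of snapshots seen so far are needed.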