drytorch.lib.models
Module containing classes for wrapping a torch module and its optimizer.
Classes
| Class | Description |
|---|---|
| AveragedModel | Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel. |
| EMAModel | Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel. |
| Model | Wrapper for a torch.nn.Module class with extra information. |
| ModuleProtocol | Protocol for a PyTorch module with type annotations. |
| SWAModel | Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel. |
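The ModuleProtocol type used in the signatures below describes a structural type rather than a base class. The following is a minimal sketch of what such a protocol could look like, assuming it only requires a typed `forward` method. `TypedModule`, `Doubler`, and `run` are hypothetical names, lists of floats stand in for tensors, and the real ModuleProtocol may declare more members.

```python
from typing import Protocol, TypeVar, runtime_checkable

# Hypothetical type variables playing the role of Input and Output.
InputT = TypeVar("InputT", contravariant=True)
OutputT = TypeVar("OutputT", covariant=True)


@runtime_checkable
class TypedModule(Protocol[InputT, OutputT]):
    """Hypothetical stand-in for ModuleProtocol: any object with a typed forward."""

    def forward(self, inputs: InputT) -> OutputT: ...


class Doubler:
    """Toy 'module' that works on lists of floats instead of tensors."""

    def forward(self, inputs: list[float]) -> list[float]:
        return [2.0 * x for x in inputs]


def run(module: TypedModule[list[float], list[float]], batch: list[float]) -> list[float]:
    # Doubler never inherits from TypedModule; it matches structurally.
    return module.forward(batch)


print(run(Doubler(), [1.0, 2.5]))  # [2.0, 5.0]
print(isinstance(Doubler(), TypedModule))  # True (runtime check of method names only)
```

A type checker accepts `Doubler` wherever the protocol is expected without any inheritance, which is the point of annotating modules with a protocol instead of a concrete class.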
- class AveragedModel(torch_module: ModuleProtocol[Input, Output], /, name: str = '', device: device | None = None, checkpoint: CheckpointProtocol | None = None, mixed_precision: bool = False)
Bases: Model[Input, Output], ABC
Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.
Use the averaged model when in inference mode.
- exec_averaged_module
the averaged module.
Initialize.
- Parameters:
torch_module (ModuleProtocol[Input, Output]) – PyTorch module with type annotations.
name (str) – the name of the model. Default uses the class name.
device (torch.device | None) – the device where to store the weights of the module. Defaults to CUDA when available, CPU otherwise.
checkpoint (CheckpointProtocol | None) – object that saves the model state and, optionally, the optimizer state.
mixed_precision (bool) – whether to use mixed precision computing. Defaults to False.
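The rule "use the averaged model when in inference mode" amounts to a dispatch between two copies of the weights. Below is a minimal, torch-free sketch of that rule, assuming "inference mode" means the wrapper's training flag is off. `AveragedDispatch` is a hypothetical name and plain callables stand in for the live and averaged modules; it does not show how drytorch itself switches between them.

```python
from typing import Callable


class AveragedDispatch:
    """Hypothetical sketch: route calls to the averaged copy outside training."""

    def __init__(
        self,
        live: Callable[[float], float],
        averaged: Callable[[float], float],
    ) -> None:
        self.live = live
        self.averaged = averaged
        self.training = True  # plays the role of torch.nn.Module.training

    def __call__(self, x: float) -> float:
        # Training uses the live weights; inference uses the averaged ones.
        module = self.live if self.training else self.averaged
        return module(x)


model = AveragedDispatch(live=lambda x: 2.0 * x, averaged=lambda x: 1.5 * x)
print(model(2.0))  # 4.0 (training: live module)
model.training = False
print(model(2.0))  # 3.0 (inference: averaged module)
```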
- class EMAModel(torch_module: ModuleProtocol[Input, Output], /, name: str = '', device: device | None = None, checkpoint: CheckpointProtocol | None = None, mixed_precision: bool = False, decay: float = 0.999)
Bases: AveragedModel[Input, Output]
Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.
Use the averaged model when in inference mode.
- exec_averaged_module
the averaged module.
Initialize.
- Parameters:
torch_module (ModuleProtocol[Input, Output]) – PyTorch module with type annotations.
name (str) – the name of the model. Default uses the class name.
device (torch.device | None) – the device where to store the weights of the module. Defaults to CUDA when available, CPU otherwise.
checkpoint (CheckpointProtocol | None) – object that saves the model state and, optionally, the optimizer state.
mixed_precision (bool) – whether to use mixed precision computing. Defaults to False.
decay (float) – the exponential decay rate for the moving average.
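Assuming EMAModel follows the usual exponential-moving-average rule (the one built by `torch.optim.swa_utils.get_ema_multi_avg_fn`), each update computes `averaged = decay * averaged + (1 - decay) * current`. The sketch below applies that rule to a single scalar weight. `ema_update` is a hypothetical helper, and the average is seeded with a copy of the first weights, as torch's AveragedModel does on its first update.

```python
def ema_update(averaged: list[float], current: list[float], decay: float) -> list[float]:
    """Return decay * averaged + (1 - decay) * current, element-wise."""
    return [decay * a + (1.0 - decay) * c for a, c in zip(averaged, current)]


# Toy trajectory of a single weight; the average starts as a copy of step 0.
weights_per_step = [[0.0], [1.0], [1.0], [1.0]]
averaged = list(weights_per_step[0])
for weights in weights_per_step[1:]:
    averaged = ema_update(averaged, weights, decay=0.5)
print(averaged)  # [0.875]
```

With the default `decay=0.999`, each update moves the average only 0.1% of the way toward the current weights, so the averaged model changes slowly and smooths out step-to-step noise.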
- class Model(module: ModuleProtocol[Input, Output], name: str = '', device: device | None = None, checkpoint: CheckpointProtocol | None = None, mixed_precision: bool = False, should_compile: bool = True, should_distribute: bool = True)
Bases: CreatedAtMixin, ModelProtocol[Input, Output]
Wrapper for a torch.nn.Module with extra information such as its name, device, and checkpoint manager.
- exec_module
PyTorch module used for execution.
- checkpoint
checkpoint manager.
Initialize.
The should_distribute option assumes that each process has exactly one accelerator and that the device for the process has already been set.
- Parameters:
module (ModuleProtocol[Input, Output]) – PyTorch module with type annotations.
name (str) – the name of the model. Default uses the class name.
device (torch.device | None) – the device where to store the weights of the module. Defaults to the accelerator if available, CPU otherwise.
checkpoint (CheckpointProtocol | None) – object that saves the model state and, optionally, the optimizer state.
mixed_precision (bool) – whether to use mixed precision computing. Defaults to False.
should_compile (bool) – whether to compile the module at instantiation (only on Python < 3.14). Defaults to True.
should_distribute (bool) – whether to wrap the module for distributed data-parallel training. Defaults to True.
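Two of the flags above depend on the environment rather than on the module. The helpers below are a hypothetical sketch of those preconditions, not drytorch code: `compile_enabled` mirrors the documented Python < 3.14 caveat, and `process_device` shows the one-accelerator-per-process convention that should_distribute expects, assuming the job is launched with torchrun, which exports `LOCAL_RANK` to each process it starts.

```python
from __future__ import annotations

import os
import sys
from collections.abc import Mapping


def compile_enabled(should_compile: bool, version: tuple[int, int] | None = None) -> bool:
    """Hypothetical: honor should_compile only on Python < 3.14."""
    if version is None:
        version = (sys.version_info.major, sys.version_info.minor)
    return should_compile and version < (3, 14)


def process_device(env: Mapping[str, str] | None = None) -> str:
    """Hypothetical: pick one accelerator per process from torchrun's LOCAL_RANK."""
    env = os.environ if env is None else env
    local_rank = env.get("LOCAL_RANK")
    # Without LOCAL_RANK this is not a torchrun worker, so fall back to CPU.
    return "cpu" if local_rank is None else f"cuda:{local_rank}"


print(compile_enabled(True, (3, 12)))  # True
print(compile_enabled(True, (3, 14)))  # False
print(process_device({"LOCAL_RANK": "1"}))  # cuda:1
```

In a real script, the selected device would typically be made current for the process (for example with `torch.cuda.set_device`) before constructing Model, so that the assumption behind should_distribute holds.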
- __call__(inputs: Input) -> Output
Execute forward pass.
- Parameters:
inputs (Input)
- Return type:
Output
- class SWAModel(torch_module: ModuleProtocol[Input, Output], /, start_epoch: int, name: str = '', device: device | None = None, checkpoint: CheckpointProtocol | None = None, mixed_precision: bool = False)
Bases: AveragedModel[Input, Output]
Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.
Use the averaged model when in inference mode.
- exec_averaged_module
the averaged module.
- start_epoch
the epoch at which to start averaging.
Initialize.
- Parameters:
torch_module (ModuleProtocol[Input, Output]) – PyTorch module with type annotations.
start_epoch (int) – the epoch at which to start averaging.
name (str) – the name of the model. Default uses the class name.
device (torch.device | None) – the device where to store the weights of the module. Defaults to CUDA when available, CPU otherwise.
checkpoint (CheckpointProtocol | None) – object that saves the model state and, optionally, the optimizer state.
mixed_precision (bool) – whether to use mixed precision computing. Defaults to False.
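Unlike EMA, stochastic weight averaging gives every snapshot equal weight once averaging has started. The sketch below assumes the running-mean update that torch.optim.swa_utils.AveragedModel uses by default, `averaged += (current - averaged) / (n_averaged + 1)`, and further assumes that epochs are counted from 1 and that start_epoch itself is included. `swa_average` is a hypothetical helper working on one scalar weight per epoch.

```python
from __future__ import annotations


def swa_average(weights_per_epoch: list[float], start_epoch: int) -> float | None:
    """Equal-weight running mean of the weights from start_epoch onward."""
    averaged: float | None = None
    n_averaged = 0
    for epoch, weight in enumerate(weights_per_epoch, start=1):
        if epoch < start_epoch:
            continue  # averaging has not started yet
        if averaged is None:
            averaged = weight  # first update copies the current weight
        else:
            averaged += (weight - averaged) / (n_averaged + 1)
        n_averaged += 1
    return averaged


print(swa_average([4.0, 2.0, 6.0, 1.0], start_epoch=2))  # 3.0, the mean of 2, 6, 1
```

The incremental form avoids storing every snapshot: after n updates the running value equals the plain mean of the n weights seen since start_epoch.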