drytorch.lib.models
Module containing classes for wrapping a torch module and its optimizer.
Functions
Count the number of parameters.
Classes
Model – Wrapper for a torch.nn.Module class with extra information.
ModelAverage – Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.
ModelOptimizer – Bundle the module and its optimizer.
- class Model(module: ModuleProtocol[Input, Output], name: str = '', device: device | None = None, checkpoint: CheckpointProtocol | None = None, mixed_precision: bool = False, should_compile: bool = True, should_distribute: bool = True)[source]
Bases: CreatedAtMixin, ModelProtocol[Input, Output]
Wrapper for a torch.nn.Module class with extra information.
- module
PyTorch module to optimize.
- checkpoint
checkpoint manager.
Initialize.
The should_distribute option assumes that each process has a single accelerator and that the device for the process is already set.
- Parameters:
module (Module) – PyTorch module with type annotations.
name (str) – the name of the model. Default uses the class name.
device (torch.device | None) – the device where to store the weights of the module. Default uses the accelerator if available, cpu otherwise.
checkpoint (CheckpointProtocol) – class that saves the state and optionally the optimizer.
mixed_precision (bool) – whether to use mixed precision computing.
should_compile (bool) – whether to compile the module at instantiation (Python < 3.14 only).
should_distribute (bool) – whether to wrap the module for data-distributed settings.
- __call__(inputs: Input) Output[source]
Execute forward pass.
- Parameters:
inputs (Input)
- Return type:
Output
- class ModelAverage(torch_module: ModuleProtocol[Input, Output], /, name: str = '', device: device | None = None, checkpoint: CheckpointProtocol = <drytorch.lib.checkpoints.LocalCheckpoint object>, mixed_precision: bool = False, avg_fn: Callable[[Tensor, Tensor, Tensor | int], Tensor] | None = None, multi_avg_fn: Callable[[tuple[Tensor, ...] | list[Tensor], tuple[Tensor, ...] | list[Tensor], Tensor | int], None] | None = None, use_buffers: bool = False)[source]
Bases: Model[Input, Output]
Bundle a torch.nn.Module and a torch.optim.swa_utils.AveragedModel.
Use the averaged model when in inference mode.
- averaged_module
the averaged module.
Initialize.
- Parameters:
torch_module (p.ModuleProtocol[Input, Output]) – PyTorch module with type annotations.
name (str) – the name of the model. Default uses the class name.
device (torch.device | None) – the device where to store the weights of the module. Default uses cuda when available, cpu otherwise.
checkpoint (CheckpointProtocol) – class that saves the state and optionally the optimizer.
mixed_precision (bool) – whether to use mixed precision computing. Defaults to False.
avg_fn (Callable[[Tensor, Tensor, Tensor | int], Tensor] | None) – see docs at torch.optim.swa_utils.AveragedModel.
multi_avg_fn (Callable[[ParamList, ParamList, Tensor | int], None] | None) – see docs at torch.optim.swa_utils.AveragedModel.
use_buffers (bool) – see docs at torch.optim.swa_utils.AveragedModel.
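The averaging is delegated to torch.optim.swa_utils.AveragedModel. The following is a minimal sketch in plain PyTorch, not drytorch's internal code, of how an averaged copy is updated during training and how inference could select it. The pick_module helper is hypothetical and only illustrates the "averaged model in inference mode" behaviour described above.

```python
import torch
from torch.optim.swa_utils import AveragedModel

module = torch.nn.Linear(2, 1)
# With avg_fn=None and multi_avg_fn=None, an equal-weight running average is kept.
averaged = AveragedModel(module)


def pick_module(training: bool) -> torch.nn.Module:
    """Hypothetical helper: train the raw module, infer with the averaged copy."""
    return module if training else averaged


# Simulate two training steps by overwriting the weights directly.
with torch.no_grad():
    module.weight.fill_(0.0)
    module.bias.fill_(0.0)
averaged.update_parameters(module)  # the first update copies the parameters

with torch.no_grad():
    module.weight.fill_(2.0)
averaged.update_parameters(module)  # running mean of 0.0 and 2.0

inputs = torch.ones(1, 2)
train_out = pick_module(training=True)(inputs)   # uses the latest weights
eval_out = pick_module(training=False)(inputs)   # uses the averaged weights
```

Here the averaged weights sit halfway between the two simulated steps, so the two outputs differ even though both modules share the same architecture.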
- class ModelOptimizer(model: ModelProtocol[Input, Output], learning_schema: LearningProtocol)[source]
Bases: object
Bundle the module and its optimizer.
It supports different learning rates for separate parameter groups.
Initialize.
- Parameters:
model (p.ModelProtocol[Input, Output]) – the model to be optimized.
learning_schema (p.LearningProtocol) – the learning scheme for the optimizer.
- property base_lr: float | dict[str, float]
Learning rate(s) for the module parameters.
- Raises:
MissingParamError – if parameters are missing from the dictionary.
- get_opt_params() list[_OptParams][source]
Actual learning rates for each parameter group, updated according to the scheduler.
- Return type:
list[_OptParams]
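To illustrate how a base_lr given as either a float or a dictionary might be expanded into optimizer parameter groups, here is a minimal sketch. It is not the library's implementation: build_param_groups is a hypothetical helper, the MissingParamError class below is a local stand-in for the library's error, and matching dictionary keys against the top-level prefix of each parameter name is an assumption.

```python
from collections.abc import Iterable
from typing import Any, Union

BaseLR = Union[float, dict[str, float]]


class MissingParamError(KeyError):
    """Local stand-in for the library's error of the same name."""


def build_param_groups(
    named_params: Iterable[tuple[str, Any]],
    base_lr: BaseLR,
) -> list[dict[str, Any]]:
    """Hypothetical helper: attach a learning rate to each group of parameters.

    Assumes dictionary keys match the top-level prefix of parameter names,
    e.g. 'encoder' for 'encoder.weight'.
    """
    named_params = list(named_params)
    if isinstance(base_lr, (int, float)):
        # A single global value: every parameter goes into one group.
        return [{'params': [p for _, p in named_params], 'lr': float(base_lr)}]

    groups = {name: {'params': [], 'lr': lr} for name, lr in base_lr.items()}
    for full_name, param in named_params:
        prefix = full_name.split('.')[0]
        if prefix not in groups:
            # A parameter without a learning rate is treated as an error.
            raise MissingParamError(full_name)
        groups[prefix]['params'].append(param)
    return list(groups.values())


# Strings stand in for tensors; module.named_parameters() yields the same shape.
fake_params = [
    ('encoder.weight', 'w_enc'),
    ('encoder.bias', 'b_enc'),
    ('head.weight', 'w_head'),
]
groups = build_param_groups(fake_params, {'encoder': 1e-4, 'head': 1e-3})
```

The resulting list of dictionaries with 'params' and 'lr' keys follows the per-parameter options format accepted by torch.optim optimizers.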
- get_scheduled_lr(lr: float) float[source]
Update the base learning rate according to the scheduler.
- load(epoch: int = -1) None[source]
Load model and optimizer state from a checkpoint.
- Parameters:
epoch (int)
- Return type:
None
- update_learning_rate(base_lr: float | dict[str, float] | None = None, scheduler: SchedulerProtocol | None = None) None[source]
Recalculate the learning rates for the current epoch.
It updates the learning rate of each parameter group in the optimizer based on the input learning rate(s) and the scheduler.
- Parameters:
base_lr (float | dict[str, float] | None) – initial learning rates for named parameters or global value. Default keeps the original learning rates.
scheduler (SchedulerProtocol | None) – scheduler for the learning rates. Default keeps the original scheduler.
- Return type:
None
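The recalculation can be pictured as applying a scheduler to each base learning rate for the current epoch. The sketch below is an assumption-laden illustration, not the library's code: the (base_lr, epoch) -> lr scheduler signature, cosine_scheduler, and scheduled_lrs are all hypothetical names, and the actual SchedulerProtocol may differ.

```python
import math
from collections.abc import Callable
from typing import Union

# Hypothetical scheduler signature: (base_lr, epoch) -> scheduled lr.
Scheduler = Callable[[float, int], float]
BaseLR = Union[float, dict[str, float]]


def cosine_scheduler(total_epochs: int, min_factor: float = 0.0) -> Scheduler:
    """Return a cosine decay from base_lr down to min_factor * base_lr."""

    def schedule(base_lr: float, epoch: int) -> float:
        progress = min(epoch, total_epochs) / total_epochs
        cosine = 0.5 * (1 + math.cos(math.pi * progress))
        return base_lr * (min_factor + (1 - min_factor) * cosine)

    return schedule


def scheduled_lrs(base_lr: BaseLR, scheduler: Scheduler, epoch: int) -> BaseLR:
    """Apply the scheduler to a global value or to each named value."""
    if isinstance(base_lr, dict):
        return {name: scheduler(lr, epoch) for name, lr in base_lr.items()}
    return scheduler(base_lr, epoch)


scheduler = cosine_scheduler(total_epochs=10)
start_lr = scheduled_lrs(0.1, scheduler, epoch=0)
end_lr = scheduled_lrs(0.1, scheduler, epoch=10)
mid_lrs = scheduled_lrs({'encoder': 1e-4, 'head': 1e-3}, scheduler, epoch=5)
```

With a torch.optim optimizer, the scheduled values would then be written into the 'lr' entry of each dictionary in optimizer.param_groups.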