drytorch.lib.training
Module containing classes for training a model.
Classes
| ModelOptimizer | Bundle the module and its optimizer. |
| Trainer | Implement the standard PyTorch training loop. |
- class ModelOptimizer(model: ModelProtocol[Input, Output], learning_schema: LearningProtocol)
Bases: object
Bundle the module and its optimizer.
It supports different learning rates for separate parameter groups.
Initialize.
- Parameters:
model (p.ModelProtocol[Input, Output]) – the model to be optimized.
learning_schema (p.LearningProtocol) – the learning scheme for the optimizer.
- property base_lr: float | dict[str, float]
Learning rate(s) for the module parameters.
- Raises:
MissingParamError – if parameters are missing from the dictionary.
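Because base_lr may be either a single float or a dictionary keyed by parameter name, a dictionary has to cover every group. The sketch below illustrates that rule in plain Python; it is not drytorch's implementation, and both `resolve_group_lrs` and the local `MissingParamError` class are hypothetical stand-ins.

```python
# Hypothetical sketch: expand a global or per-name base learning rate into
# one learning rate per parameter group. Not drytorch's implementation.
class MissingParamError(KeyError):
    """Local stand-in for the library's error, defined only for this sketch."""


def resolve_group_lrs(base_lr, group_names):
    """Map each parameter-group name to its learning rate."""
    if isinstance(base_lr, (int, float)):
        # A single number applies to every group.
        return {name: float(base_lr) for name in group_names}
    # A dictionary must provide a rate for every group.
    missing = [name for name in group_names if name not in base_lr]
    if missing:
        raise MissingParamError(f"missing learning rates for: {missing}")
    return {name: base_lr[name] for name in group_names}


print(resolve_group_lrs(0.01, ["encoder", "head"]))
# {'encoder': 0.01, 'head': 0.01}
print(resolve_group_lrs({"encoder": 1e-4, "head": 1e-2}, ["encoder", "head"]))
# {'encoder': 0.0001, 'head': 0.01}
```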
- get_opt_params() → list[_OptParams]
Return the actual learning rates for each parameter group, updated according to the scheduler.
- Return type:
list[_OptParams]
- get_scheduled_lr(lr: float) → float
Update the base learning rate according to the scheduler.
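To make the relation between get_scheduled_lr and get_opt_params concrete, here is a minimal sketch that applies one schedule to every group's base rate. It assumes a scheduler is a callable taking the base rate and the epoch; SchedulerProtocol's real signature is not shown on this page, and `ExponentialScheduler` and `scheduled_group_lrs` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass
class ExponentialScheduler:
    """Hypothetical scheduler: multiplies the base rate by gamma each epoch."""
    gamma: float = 0.5

    def __call__(self, base_lr: float, epoch: int) -> float:
        return base_lr * self.gamma ** epoch


def scheduled_group_lrs(base_lrs, scheduler, epoch):
    """Apply the same schedule to the base learning rate of every group."""
    return {name: scheduler(lr, epoch) for name, lr in base_lrs.items()}


lrs = scheduled_group_lrs({"encoder": 0.1, "head": 0.4}, ExponentialScheduler(0.5), epoch=2)
print(lrs)  # {'encoder': 0.025, 'head': 0.1}
```

Keeping the scheduler separate from the base rates means the same schedule shape can be reused for groups that start from different learning rates.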
- update_learning_rate(base_lr: float | dict[str, float] | None = None, scheduler: SchedulerProtocol | None = None) → None
Recalculate the learning rates for the current epoch.
It updates the learning rate of each parameter group in the optimizer based on the input learning rate(s) and scheduler.
- Parameters:
base_lr (float | dict[str, float] | None) – the initial learning rate, either as a global value or as a dictionary mapping parameter names to their learning rates. Defaults to None, which keeps the original learning rates.
scheduler (SchedulerProtocol | None) – the scheduler for the learning rates. Defaults to None, which keeps the original scheduler.
- Return type:
None
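The "None keeps the original value" convention of update_learning_rate can be sketched as below. This is an illustration only: `LearningRateSettings` and `constant_schedule` are hypothetical, and the scheduler is again assumed to be a `(base_lr, epoch) -> lr` callable.

```python
from __future__ import annotations

from typing import Callable

# Hypothetical scheduler signature: (base_lr, epoch) -> lr.
Scheduler = Callable[[float, int], float]


def constant_schedule(base_lr: float, epoch: int) -> float:
    """Keep the base learning rate unchanged at every epoch."""
    return base_lr


class LearningRateSettings:
    """Hypothetical holder illustrating the 'None keeps the current value' rule."""

    def __init__(self, base_lr: float | dict[str, float],
                 scheduler: Scheduler = constant_schedule) -> None:
        self.base_lr = base_lr
        self.scheduler = scheduler

    def update_learning_rate(self, base_lr: float | dict[str, float] | None = None,
                             scheduler: Scheduler | None = None) -> None:
        # Only overwrite the settings that were explicitly passed.
        if base_lr is not None:
            self.base_lr = base_lr
        if scheduler is not None:
            self.scheduler = scheduler


settings = LearningRateSettings(0.1)
settings.update_learning_rate(scheduler=lambda lr, epoch: lr * 0.9 ** epoch)
print(settings.base_lr)  # 0.1 (unchanged, since base_lr was not passed)
```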
- class Trainer(model: ModelProtocol[Input, Output], name: str = '', *, loader: LoaderProtocol[tuple[Input, Target]], loss: LossProtocol[Output, Target], learning_schema: LearningProtocol)
Bases: ModelRunnerWithLogs[Input, Target, Output, LossProtocol[Output, Target]], TrainerProtocol[Input, Target, Output]
Implement the standard PyTorch training loop.
- model
the model to train.
- Type:
drytorch.core.protocols.ModelProtocol[drytorch.lib.runners.Input, drytorch.lib.runners.Output]
- loader
provides inputs and targets in batches.
- Type:
drytorch.core.protocols.LoaderProtocol[tuple[drytorch.lib.runners.Input, drytorch.lib.runners.Target]]
- objective
determines the optimization criterion.
- Type:
drytorch.lib.runners._Objective_co
- learning_schema
contains optimizer settings and scheduling.
- validation
the object that validates the model.
Initialize.
- Parameters:
model (ModelProtocol[Input, Output]) – the model containing the weights to train.
name (str) – the base name of the object for logging purposes. Defaults to the class name, followed by a counter if needed.
loader (LoaderProtocol[tuple[Input, Target]]) – provides inputs and targets in batches.
loss (p.LossProtocol[Output, Target]) – determines the optimization criterion.
learning_schema (LearningProtocol) – contains optimizer settings and scheduling.
- __call__(store_outputs: bool = False) → None
Train the module for one epoch.
- Parameters:
store_outputs (bool) – whether to store model outputs.
- Return type:
None
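One call runs the usual per-batch cycle: forward pass, loss, gradient, optimizer step, optionally keeping the outputs. The sketch below is a framework-free analogue, not drytorch's code: it fits a scalar weight `w` for `y_hat = w * x` with mean squared error and a hand-derived gradient in place of autograd; `train_one_epoch` is a hypothetical name.

```python
def train_one_epoch(w, loader, lr, store_outputs=False):
    """Run one pass over `loader`, updating the scalar weight `w` by SGD.

    The model is y_hat = w * x and the loss is the mean squared error,
    so the gradient is computed by hand instead of by autograd.
    """
    outputs = []
    for inputs, targets in loader:
        preds = [w * x for x in inputs]  # forward pass
        n = len(inputs)
        # d/dw of mean((w * x - y) ** 2) is mean(2 * (w * x - y) * x).
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, targets, inputs)) / n
        w -= lr * grad  # optimizer step
        if store_outputs:
            outputs.extend(preds)
    return w, outputs


# Two mini-batches drawn from y = 2 * x.
loader = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
w = 0.0
for epoch in range(50):
    w, _ = train_one_epoch(w, loader, lr=0.05)
print(round(w, 4))  # 2.0
```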
- add_validation(val_loader: LoaderProtocol[tuple[Input, Target]], name: str = '', interval: int = 1) → None
Add a loader for validation with the same metrics as for training.
If multiple validation loaders are added, all of them are run, but only the last one is stored as the instance's validation attribute.
- Parameters:
val_loader (LoaderProtocol[tuple[Input, Target]]) – the loader for validation.
name (str) – the name for the validation.
interval (int) – how often to run the validation, in epochs (1 means every epoch).
- Raises:
ValueError – if the interval is not strictly positive.
- Return type:
None
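The bookkeeping described above (every registered validation runs, only the latest is kept as `validation`, non-positive intervals are rejected) can be sketched as follows. `ValidationSchedule` and `due` are hypothetical, and running a validation when the 1-based epoch is a multiple of its interval is an assumption about how the interval is applied.

```python
class ValidationSchedule:
    """Hypothetical bookkeeping for validation loaders and their intervals."""

    def __init__(self):
        self.validations = []   # every registered (name, interval) pair
        self.validation = None  # only the most recently added one

    def add_validation(self, name: str, interval: int = 1) -> None:
        if interval <= 0:
            raise ValueError("interval must be strictly positive")
        self.validations.append((name, interval))
        self.validation = name

    def due(self, epoch: int) -> list:
        """Names of the validations to run after `epoch` (1-based), assuming
        a validation runs whenever the epoch is a multiple of its interval."""
        return [name for name, interval in self.validations if epoch % interval == 0]


sched = ValidationSchedule()
sched.add_validation("val", interval=1)
sched.add_validation("holdout", interval=3)
print(sched.due(3), sched.validation)  # ['val', 'holdout'] holdout
```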
- load_checkpoint(epoch: int = -1) → None
Load model and optimizer state from a checkpoint.
- Parameters:
epoch (int) – the epoch from which to load the checkpoint. Defaults to the last saved epoch.
- Return type:
None
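A checkpoint pairs the model state with the optimizer state, and `epoch=-1` selects the most recent one. The in-memory store below only illustrates that lookup rule; `CheckpointStore` is hypothetical, and a real PyTorch setup would serialize `state_dict()` objects to disk rather than keep plain dictionaries.

```python
import copy


class CheckpointStore:
    """Hypothetical in-memory store of model and optimizer states per epoch."""

    def __init__(self):
        self._states = {}

    def save(self, epoch: int, model_state: dict, optimizer_state: dict) -> None:
        # Deep copies so later training does not mutate the saved snapshot.
        self._states[epoch] = {
            "model": copy.deepcopy(model_state),
            "optimizer": copy.deepcopy(optimizer_state),
        }

    def load(self, epoch: int = -1) -> dict:
        if not self._states:
            raise LookupError("no checkpoint has been saved")
        if epoch == -1:
            epoch = max(self._states)  # default: the last saved epoch
        return copy.deepcopy(self._states[epoch])


store = CheckpointStore()
store.save(1, {"w": [0.5]}, {"lr": 0.1})
store.save(2, {"w": [0.9]}, {"lr": 0.05})
print(store.load()["model"])  # {'w': [0.9]}
```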
- terminate_training(reason: str) → None
Prevent the trainer from continuing the training.
- Parameters:
reason (str) – the reason for terminating the training.
- Return type:
None
- train(n_epochs: int) → None
Train the module for the specified number of epochs.
- Parameters:
n_epochs (int) – the number of epochs for which to train the module.
- Return type:
None
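train and terminate_training interact through a flag: once termination is requested, the remaining epochs are skipped. The sketch below shows one way this can work, with a simple early-stopping rule calling terminate_training. `TinyTrainer`, its `patience` rule, and the precomputed per-epoch losses are hypothetical, and checking the flag before each epoch is an assumption rather than drytorch's documented behavior.

```python
class TinyTrainer:
    """Hypothetical trainer: train() stops once terminate_training() is called."""

    def __init__(self, losses, patience=2):
        self._losses = iter(losses)  # stand-in for real per-epoch losses
        self.patience = patience
        self.epoch = 0
        self.best = float("inf")
        self.stale = 0               # consecutive epochs without improvement
        self.terminated_reason = None

    def terminate_training(self, reason: str) -> None:
        self.terminated_reason = reason

    def __call__(self) -> None:
        """Run one epoch, then apply a simple early-stopping rule."""
        self.epoch += 1
        loss = next(self._losses)
        if loss < self.best:
            self.best, self.stale = loss, 0
        else:
            self.stale += 1
        if self.stale >= self.patience:
            self.terminate_training(f"no improvement for {self.patience} epochs")

    def train(self, n_epochs: int) -> None:
        for _ in range(n_epochs):
            if self.terminated_reason is not None:
                break  # training was terminated; skip the remaining epochs
            self()


trainer = TinyTrainer([1.0, 0.5, 0.6, 0.7, 0.4, 0.3], patience=2)
trainer.train(6)
print(trainer.epoch, trainer.terminated_reason)  # 4 no improvement for 2 epochs
```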
- update_learning_rate(base_lr: float | dict[str, float] | None = None, scheduler: SchedulerProtocol | None = None) → None
Update the learning rate(s).
It updates the learning rate of each parameter group in the optimizer based on the input learning rate(s) and scheduler.
- Parameters:
base_lr (float | dict[str, float] | None) – the initial learning rate, either as a global value or as a dictionary mapping parameter names to their learning rates. Defaults to None, which keeps the original learning rates.
scheduler (SchedulerProtocol | None) – the scheduler for the learning rates. Defaults to None, which keeps the original scheduler.
- Return type:
None