drytorch.core.protocols
Module containing internal protocols.
Classes
- CheckpointProtocol: Protocol that stores and loads weights for a ModelProtocol class.
- GradientOpProtocol: Abstract base class for gradient operations.
- LearningProtocol: Protocol with specifications for the learning algorithm.
- LoaderProtocol: Protocol for loading and batching a dataset.
- LossProtocol: Protocol that calculates and returns metrics and the loss.
- ModelProtocol: Protocol for a wrapper around a torch module.
- MonitorProtocol: Protocol for a class that validates a model.
- ObjectiveProtocol: Protocol that calculates and returns metrics.
- SchedulerProtocol: Protocol of a scheduler for the learning rate.
- TrainerProtocol: Protocol for a class that trains a model.
- class CheckpointProtocol(*args, **kwargs)
Bases: Protocol
Protocol that stores and loads weights for a ModelProtocol class.
- bind_model(model: ModelProtocol[Any, Any]) -> None
Bind the model to manage.
- Parameters:
model (ModelProtocol[Any, Any])
- Return type:
None
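Because CheckpointProtocol is a typing.Protocol, a class conforms structurally by providing a compatible bind_model method; it does not need to inherit from the protocol. The sketch below is illustrative only: CheckpointLike is a local stand-in that mirrors just the documented bind_model signature, and InMemoryCheckpoint, its save/load methods, and ToyModel are hypothetical names, not drytorch API.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class CheckpointLike(Protocol):
    """Hypothetical stand-in mirroring only the documented bind_model method."""

    def bind_model(self, model: Any) -> None: ...


class ToyModel:
    """Hypothetical model whose "weights" are a plain dict."""

    def __init__(self) -> None:
        self.weights = {"w": 1.0}


class InMemoryCheckpoint:
    """Hypothetical checkpoint storing weight snapshots in memory, keyed by epoch."""

    def __init__(self) -> None:
        self.model: Any = None
        self._snapshots: dict[int, dict[str, float]] = {}

    def bind_model(self, model: Any) -> None:
        # Remember which model this checkpoint manages.
        self.model = model

    def save(self, epoch: int) -> None:
        # Copy the bound model's weights so later updates do not alter the snapshot.
        self._snapshots[epoch] = dict(self.model.weights)

    def load(self, epoch: int) -> None:
        # Restore the weights saved at the given epoch.
        self.model.weights = dict(self._snapshots[epoch])


model = ToyModel()
checkpoint = InMemoryCheckpoint()
checkpoint.bind_model(model)
checkpoint.save(epoch=1)
model.weights["w"] = 5.0  # simulate a training update
checkpoint.load(epoch=1)
print(isinstance(checkpoint, CheckpointLike), model.weights)  # True {'w': 1.0}
```

Note that isinstance against a runtime_checkable protocol only checks that the required methods exist, not that their signatures match.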
- class GradientOpProtocol(*args, **kwargs)
Bases: Protocol
Abstract base class for gradient operations.
- class LearningProtocol(*args, **kwargs)
Bases: Protocol
Protocol with specifications for the learning algorithm.
- optimizer_cls
the optimizer class to bind to the module.
- base_lr
the initial learning rate, either a single global value or one value per named parameter.
- scheduler
modifies the learning rate given the current epoch.
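A learning schema can be as simple as an immutable record holding these three attributes. A minimal sketch, assuming a scheduler is a callable (base_lr, epoch) -> lr; that call convention, the LearningSchema class, and the placeholder SGD class are hypothetical stand-ins (in practice optimizer_cls would typically be a torch.optim optimizer class).

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Mapping


class SGD:
    """Hypothetical placeholder for an optimizer class such as torch.optim.SGD."""

    def __init__(self, params: Any, lr: float) -> None:
        self.params = params
        self.lr = lr


def constant(base_lr: float, epoch: int) -> float:
    # Hypothetical scheduler: the learning rate never changes.
    return base_lr


@dataclass(frozen=True)
class LearningSchema:
    """Hypothetical record with the attributes LearningProtocol documents."""

    optimizer_cls: type
    base_lr: float | Mapping[str, float]  # global value or one per named parameter
    scheduler: Callable[[float, int], float] = constant


# One global learning rate ...
global_schema = LearningSchema(optimizer_cls=SGD, base_lr=0.01)
# ... or one learning rate per named parameter.
per_param_schema = LearningSchema(
    optimizer_cls=SGD, base_lr={"encoder": 1e-4, "head": 1e-3}
)
print(global_schema.scheduler(0.01, epoch=5))  # 0.01
```

Freezing the dataclass keeps the schema from being mutated after a trainer has read it; that is a design choice of this sketch, not a documented requirement.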
- class LoaderProtocol(*args, **kwargs)
Bases: Protocol[_Data_co]
Protocol for loading and batching a dataset.
- dataset
the dataset to load.
- Type:
- sampler
the sampler used to select the samples.
- Type:
torch.utils.data.sampler.Sampler[Any] | collections.abc.Iterable[Any]
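The loading and batching described above can be sketched with plain Python sequences. This assumes iteration yields lists of samples in sampler order; ListLoader and its batch_size parameter are hypothetical and stand in for a real loader such as torch.utils.data.DataLoader, which also exposes dataset and sampler attributes.

```python
from __future__ import annotations

from typing import Generic, Iterable, Iterator, Sequence, TypeVar

T = TypeVar("T")


class ListLoader(Generic[T]):
    """Hypothetical loader exposing dataset and sampler attributes."""

    def __init__(
        self,
        dataset: Sequence[T],
        batch_size: int,
        sampler: Iterable[int] | None = None,
    ) -> None:
        if batch_size < 1:
            raise ValueError("batch_size must be positive")
        self.dataset = dataset
        # Fall back to sequential order when no sampler is given.
        self.sampler: Iterable[int] = (
            sampler if sampler is not None else range(len(dataset))
        )
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[list[T]]:
        batch: list[T] = []
        for index in self.sampler:
            batch.append(self.dataset[index])
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch:  # the last batch may be smaller than batch_size
            yield batch


loader = ListLoader(["a", "b", "c", "d", "e"], batch_size=2)
print(list(loader))  # [['a', 'b'], ['c', 'd'], ['e']]
reversed_loader = ListLoader(["a", "b", "c"], batch_size=2, sampler=[2, 1, 0])
print(list(reversed_loader))  # [['c', 'b'], ['a']]
```

Keeping the sampler separate from the dataset lets the same data be visited in a different order (shuffled, reversed, subsampled) without copying it.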
- class LossProtocol(*args, **kwargs)
Bases: ObjectiveProtocol[_Output_contra, _Target_contra], Protocol
Protocol that calculates and returns metrics and the loss.
- class ModelProtocol(*args, **kwargs)
Bases: Protocol[_Input_contra, _Output_co]
Protocol for a wrapper around a torch module.
- checkpoint
the object responsible for saving and loading the model.
- abstractmethod __call__(inputs: _Input_contra) -> _Output_co
Call the module's forward method.
- Parameters:
inputs (_Input_contra)
- Return type:
_Output_co
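A wrapper with this shape needs a checkpoint attribute and a __call__ that forwards inputs to the wrapped module (calling a torch.nn.Module runs its forward method). In this sketch a plain callable replaces the torch module, and binding the checkpoint inside the constructor is an assumption; ModelWrapper and RecordingCheckpoint are hypothetical names.

```python
from typing import Any, Callable, Generic, TypeVar

In = TypeVar("In")
Out = TypeVar("Out")


class RecordingCheckpoint:
    """Hypothetical checkpoint that only records the model bound to it."""

    def __init__(self) -> None:
        self.model: Any = None

    def bind_model(self, model: Any) -> None:
        self.model = model


class ModelWrapper(Generic[In, Out]):
    """Hypothetical wrapper shaped like ModelProtocol: a checkpoint plus __call__."""

    def __init__(self, module: Callable[[In], Out], checkpoint: Any) -> None:
        self.module = module
        self.checkpoint = checkpoint
        # Assumption: the wrapper binds itself to its checkpoint on creation.
        self.checkpoint.bind_model(self)

    def __call__(self, inputs: In) -> Out:
        # Delegate to the wrapped callable (a torch.nn.Module would run forward).
        return self.module(inputs)


ckpt = RecordingCheckpoint()
double: ModelWrapper[int, int] = ModelWrapper(lambda x: 2 * x, ckpt)
print(double(21), ckpt.model is double)  # 42 True
```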
- class MonitorProtocol(*args, **kwargs)
Bases: Protocol
Protocol for a class that validates a model.
- model
the model to evaluate.
- Type:
drytorch.core.protocols.ModelProtocol[Any, Any]
- class ObjectiveProtocol(*args, **kwargs)
Bases: Protocol[_Output_contra, _Target_contra]
Protocol that calculates and returns metrics.
- abstractmethod update(outputs: _Output_contra, targets: _Target_contra, /) -> Any
Compute the metrics only.
- Parameters:
outputs (_Output_contra) – model outputs.
targets (_Target_contra) – ground truth.
- Return type:
Any
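The positional-only marker (/) in update means outputs and targets must be passed by position. A minimal sketch of an accuracy metric with that signature; the Accuracy class, its running totals, and the returned dict are hypothetical choices, since the protocol only fixes the return type as Any.

```python
from typing import Sequence


class Accuracy:
    """Hypothetical objective with the documented update(outputs, targets, /) signature."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, outputs: Sequence[int], targets: Sequence[int], /) -> dict[str, float]:
        # Accumulate counts across batches and return the running accuracy.
        self.correct += sum(o == t for o, t in zip(outputs, targets))
        self.total += len(targets)
        return {"accuracy": self.correct / self.total}


metric = Accuracy()
metric.update([1, 0, 1], [1, 1, 1])  # 2 of 3 correct
print(metric.update([0], [0]))  # {'accuracy': 0.75}

try:
    metric.update(outputs=[1], targets=[1])  # keyword arguments are rejected
except TypeError:
    print("positional-only")
```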
- class SchedulerProtocol(*args, **kwargs)
Bases: Protocol
Protocol of a scheduler for the learning rate.
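This page does not document the scheduler's methods. Assuming, purely for illustration, that a scheduler is called with the base learning rate and the epoch and returns the rate for that epoch, an exponential decay could look like this (ExponentialScheduler and its __call__ signature are hypothetical).

```python
class ExponentialScheduler:
    """Hypothetical scheduler: lr = base_lr * gamma ** epoch."""

    def __init__(self, gamma: float) -> None:
        if not 0.0 < gamma <= 1.0:
            raise ValueError("gamma must be in (0, 1]")
        self.gamma = gamma

    def __call__(self, base_lr: float, epoch: int) -> float:
        # Decay the base learning rate geometrically with the epoch number.
        return base_lr * self.gamma**epoch


scheduler = ExponentialScheduler(gamma=0.5)
print([scheduler(0.1, epoch) for epoch in range(3)])  # [0.1, 0.05, 0.025]
```

Making the scheduler a pure function of (base_lr, epoch) means the rate for any epoch can be recomputed after loading a checkpoint, without replaying earlier epochs.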
- class TrainerProtocol(*args, **kwargs)
Bases: MonitorProtocol, Protocol[_Input, _Target, _Output]
Protocol for a class that trains a model.
- model
the model to train.
- Type:
drytorch.core.protocols.ModelProtocol[drytorch.core.protocols._Input, drytorch.core.protocols._Output]
- learning_schema
the optimizer settings and the learning-rate schedule.
- objective
the loss that defines the optimization criterion.
- Type:
drytorch.core.protocols.LossProtocol[drytorch.core.protocols._Output, drytorch.core.protocols._Target]
- validation
the object that validates the model.
- Type:
- train(n_epochs: int) -> None
Train the module for the specified number of epochs.
- Parameters:
n_epochs (int) – the number of epochs to train the module for.
- Return type:
None
- terminate_training(reason: str) -> None
Prevent the trainer from continuing the training.
- Parameters:
reason (str)
- Return type:
None
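train and terminate_training can cooperate through a flag checked between epochs, so a hook or callback can stop training early without interrupting an epoch midway. A minimal sketch of that pattern with a toy early-stopping rule; ToyTrainer, its placeholder loss, and the patience parameter are hypothetical, not drytorch's implementation.

```python
class ToyTrainer:
    """Hypothetical trainer illustrating how train() and terminate_training() interact."""

    def __init__(self, patience: int) -> None:
        self.epoch = 0
        self.patience = patience
        self.losses: list[float] = []
        self.stop_reason = ""
        self._terminated = False

    def terminate_training(self, reason: str) -> None:
        # Only set a flag; the epoch loop in train() checks it before each epoch.
        self._terminated = True
        self.stop_reason = reason

    def _train_one_epoch(self) -> float:
        # Placeholder loss: 1 / (epoch + 1), floored at 0.25 so it plateaus from epoch 3.
        return max(1.0 / (self.epoch + 1), 0.25)

    def train(self, n_epochs: int) -> None:
        for _ in range(n_epochs):
            if self._terminated:
                break
            self.epoch += 1
            self.losses.append(self._train_one_epoch())
            # Early stopping: the last `patience` losses do not beat the earlier best.
            if len(self.losses) > self.patience:
                recent_best = min(self.losses[-self.patience:])
                earlier_best = min(self.losses[: -self.patience])
                if recent_best >= earlier_best:
                    self.terminate_training("loss stopped improving")


trainer = ToyTrainer(patience=2)
trainer.train(n_epochs=10)
print(trainer.epoch, trainer.stop_reason)  # 5 loss stopped improving
```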
- load_checkpoint(epoch: int = -1) -> None
Load model and optimizer state from a checkpoint.
- Parameters:
epoch (int)
- Return type:
None
- update_learning_rate(base_lr: float | None, scheduler: SchedulerProtocol | None) -> None
Update the learning rate(s).
It updates the learning rate of each parameter group in the optimizer based on the given learning rate and scheduler.
- Parameters:
base_lr (float | None) – the initial learning rate.
scheduler (SchedulerProtocol | None) – scheduler for the learning rate.
- Return type:
None
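A torch optimizer exposes its parameter groups as optimizer.param_groups, a list of dicts each holding an "lr" key. The sketch below applies the same idea to plain dicts, assuming that base_lr=None keeps each group's stored initial rate and that a scheduler is a callable (base_lr, epoch) -> lr. The standalone update_learning_rate function and its extra param_groups and epoch arguments are hypothetical; the documented method takes only base_lr and scheduler.

```python
from __future__ import annotations

from typing import Any, Callable


def update_learning_rate(
    param_groups: list[dict[str, Any]],
    epoch: int,
    base_lr: float | None = None,
    scheduler: Callable[[float, int], float] | None = None,
) -> None:
    """Hypothetical helper: recompute the "lr" entry of every parameter group."""
    for group in param_groups:
        # Remember each group's starting rate the first time the group is seen.
        group.setdefault("initial_lr", group["lr"])
        if base_lr is not None:
            group["initial_lr"] = base_lr  # a new base rate replaces the stored one
        lr = group["initial_lr"]
        if scheduler is not None:
            lr = scheduler(lr, epoch)  # assumed (base_lr, epoch) -> lr convention
        group["lr"] = lr


groups = [{"lr": 0.1}, {"lr": 0.01}]  # mimics optimizer.param_groups
update_learning_rate(groups, epoch=2, scheduler=lambda lr, epoch: lr * 0.5**epoch)
print([g["lr"] for g in groups])  # [0.025, 0.0025]
update_learning_rate(groups, epoch=0, base_lr=0.2)
print([g["lr"] for g in groups])  # [0.2, 0.2]
```

Storing the initial rate separately from the current one means the schedule is always applied to the base value rather than compounded on the previous epoch's rate.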