drytorch.core.protocols

Module containing internal protocols.

Classes

CheckpointProtocol(*args, **kwargs)

Protocol that stores and loads weights for a ModelProtocol class.

GradientOpProtocol(*args, **kwargs)

Abstract base class for gradient operations.

LearningProtocol(*args, **kwargs)

Protocol with specifications for the learning algorithm.

LoaderProtocol(*args, **kwargs)

Protocol for loading and batching a dataset.

LossProtocol(*args, **kwargs)

Protocol that calculates and returns metrics and the loss.

ModelProtocol(*args, **kwargs)

Protocol for a wrapper around a torch module.

MonitorProtocol(*args, **kwargs)

Protocol for a class that validates a model.

ObjectiveProtocol(*args, **kwargs)

Protocol that calculates and returns metrics.

SchedulerProtocol(*args, **kwargs)

Protocol of a scheduler for the learning rate.

TrainerProtocol(*args, **kwargs)

Protocol for a class that trains a model.

class CheckpointProtocol(*args, **kwargs)

Bases: Protocol

Protocol that stores and loads weights for a ModelProtocol class.

bind_model(model: ModelProtocol[Any, Any]) → None

Bind the model to manage.

Parameters:

model (ModelProtocol[Any, Any])

Return type:

None

bind_module(name: str, module: Module) → None

Bind a module connected to the model.

Parameters:
  • name (str)

  • module (Module)

Return type:

None

bind_optimizer(optimizer: Optimizer) → None

Bind the optimizer connected to the model.

Parameters:

optimizer (Optimizer)

Return type:

None

save() → None

Save the model and optimizer state dictionaries.

Return type:

None

load(epoch: int = -1) → None

Load the model and optimizer state dictionaries.

Parameters:

epoch (int)

Return type:

None

class GradientOpProtocol(*args, **kwargs)

Bases: Protocol

Abstract base class for gradient operations.

abstractmethod __call__(params: Iterable[Parameter]) → None

Apply the gradient operation to the given parameters.

Parameters:

params (Iterable[Parameter])

Return type:

None

class LearningProtocol(*args, **kwargs)

Bases: Protocol

Protocol with specifications for the learning algorithm.

optimizer_cls

the optimizer class to bind to the module.

Type:

type[torch.optim.optimizer.Optimizer]

base_lr

initial learning rates for named parameters or global value.

Type:

float | dict[str, float]

optimizer_defaults

optional arguments for the optimizer.

Type:

dict[str, Any]

scheduler

modifies the learning rate given the current epoch.

Type:

drytorch.core.protocols.SchedulerProtocol

class LoaderProtocol(*args, **kwargs)

Bases: Protocol[_Data_co]

Protocol for loading and batching a dataset.

batch_size

the batch size.

Type:

int | None

dataset

the dataset to load.

Type:

torch.utils.data.dataset.Dataset[Any]

sampler

the sampler used to select the samples.

Type:

torch.utils.data.sampler.Sampler[Any] | collections.abc.Iterable[Any]

__iter__() → Iterator[_Data_co]

Return an iterator over the dataset in batches.

Return type:

Iterator[_Data_co]

__len__() → int

Return the number of batches in the dataset.

Return type:

int
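
The iteration contract above can be illustrated with a minimal, dependency-free sketch. The class name `ListLoader` is hypothetical and not part of drytorch; a plain sequence stands in for a torch `Dataset` and a `range` stands in for the sampler. Counting a final partial batch in `__len__` is an assumption of this sketch, not something the protocol specifies.

```python
from collections.abc import Iterator, Sequence
from typing import Generic, TypeVar

T = TypeVar('T')


class ListLoader(Generic[T]):
    """Hypothetical loader yielding consecutive batches from a sequence."""

    def __init__(self, dataset: Sequence[T], batch_size: int) -> None:
        if batch_size <= 0:
            raise ValueError('batch_size must be positive')
        self.dataset = dataset
        self.batch_size = batch_size
        # Sequential order; a real sampler could shuffle the indices instead.
        self.sampler = range(len(dataset))

    def __iter__(self) -> Iterator[list[T]]:
        indices = list(self.sampler)
        for start in range(0, len(indices), self.batch_size):
            batch_indices = indices[start:start + self.batch_size]
            yield [self.dataset[i] for i in batch_indices]

    def __len__(self) -> int:
        # Ceiling division (assumption: a final partial batch is counted).
        return -(-len(self.dataset) // self.batch_size)
```

Iterating over such an object yields one list per batch, and `len()` reports how many batches a full pass produces.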

class LossProtocol(*args, **kwargs)

Bases: ObjectiveProtocol[_Output_contra, _Target_contra], Protocol

Protocol that calculates and returns metrics and the loss.

forward(outputs: _Output_contra, targets: _Target_contra, /) → Tensor

Process the outputs and targets and return the loss.

Parameters:
  • outputs (_Output_contra) – model outputs.

  • targets (_Target_contra) – ground truth.

Returns:

The computed loss.

Return type:

Tensor

class ModelProtocol(*args, **kwargs)

Bases: Protocol[_Input_contra, _Output_co]

Protocol for a wrapper around a torch module.

epoch

the number of epochs the model has been trained so far.

Type:

int

checkpoint

the object responsible for saving and loading the model.

Type:

drytorch.core.protocols.CheckpointProtocol

mixed_precision

whether to use mixed precision computing.

Type:

bool

abstractmethod __call__(inputs: _Input_contra) → _Output_co

Call the module forward method.

Parameters:

inputs (_Input_contra)

Return type:

_Output_co

property device: device

The device where the weights are stored.

property module: Module

The module wrapped by the class.

property name: str

The name of the model.

abstractmethod increment_epoch() → None

Increment the epoch by 1.

Return type:

None

abstractmethod post_batch_update() → None

Update the model after processing a batch of data.

Return type:

None

abstractmethod post_epoch_update() → None

Update the model after processing an epoch of data.

Return type:

None

class MonitorProtocol(*args, **kwargs)

Bases: Protocol

Protocol for a class that validates a model.

model

the model to evaluate.

Type:

drytorch.core.protocols.ModelProtocol[Any, Any]

property name: str

The name of the model.

property computed_metrics: Mapping[str, float]

Computed metric values.

class ObjectiveProtocol(*args, **kwargs)

Bases: Protocol[_Output_contra, _Target_contra]

Protocol that calculates and returns metrics.

abstractmethod update(outputs: _Output_contra, targets: _Target_contra, /) → Any

Compute the metrics only.

Parameters:
  • outputs (_Output_contra) – model outputs.

  • targets (_Target_contra) – ground truth.

Return type:

Any

abstractmethod compute() → Mapping[str, Tensor] | Tensor | None

Return a mapping from the metric names to the calculated values.

Return type:

Mapping[str, Tensor] | Tensor | None

abstractmethod reset() → Any

Reset cached values.

Return type:

Any

OutputType

alias of Any
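
The update / compute / reset cycle of ObjectiveProtocol can be sketched without torch. The class `MeanAbsoluteError` and the metric key `'mae'` are hypothetical, and plain floats stand in for tensors to keep the example self-contained; returning `None` before any update is one reading of the `| None` in the `compute` signature, not confirmed behavior.

```python
from collections.abc import Mapping, Sequence
from typing import Optional


class MeanAbsoluteError:
    """Hypothetical metric following the update / compute / reset cycle."""

    def __init__(self) -> None:
        self._total = 0.0
        self._count = 0

    def update(self, outputs: Sequence[float], targets: Sequence[float], /) -> None:
        # Accumulate absolute errors; nothing is returned at this stage.
        if len(outputs) != len(targets):
            raise ValueError('outputs and targets must have the same length')
        for output, target in zip(outputs, targets):
            self._total += abs(output - target)
            self._count += 1

    def compute(self) -> Optional[Mapping[str, float]]:
        # Assumption: None signals that no batch has been seen yet.
        if self._count == 0:
            return None
        return {'mae': self._total / self._count}

    def reset(self) -> None:
        # Drop the cached sums so the next epoch starts from scratch.
        self._total = 0.0
        self._count = 0
```

A typical cycle calls `update` once per batch, `compute` at the end of the epoch, and `reset` before the next one.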

class SchedulerProtocol(*args, **kwargs)

Bases: Protocol

Protocol of a scheduler for the learning rate.

__call__(base_lr: float, epoch: int) → float

Modify the learning rate according to a schedule.

Parameters:
  • base_lr (float) – initial learning rate.

  • epoch (int) – the current epoch.

Returns:

The scheduled value for the learning rate.

Return type:

float
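
Any callable with this signature satisfies the protocol structurally. As an illustration, the sketch below implements a cosine decay; the class `CosineScheduler` and its parameters `decay_epochs` and `min_factor` are hypothetical names chosen for this example and are not taken from drytorch.

```python
import math


class CosineScheduler:
    """Hypothetical scheduler: cosine decay from base_lr to min_factor * base_lr."""

    def __init__(self, decay_epochs: int, min_factor: float = 0.0) -> None:
        if decay_epochs <= 0:
            raise ValueError('decay_epochs must be positive')
        self.decay_epochs = decay_epochs
        self.min_factor = min_factor

    def __call__(self, base_lr: float, epoch: int) -> float:
        # Clamp the epoch so the rate stays at its floor once the decay ends.
        progress = min(epoch, self.decay_epochs) / self.decay_epochs
        cosine = (1 + math.cos(math.pi * progress)) / 2
        factor = self.min_factor + (1 - self.min_factor) * cosine
        return base_lr * factor
```

Because the schedule is a pure function of `base_lr` and `epoch`, it holds no optimizer state and can be evaluated for any epoch in isolation.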

class TrainerProtocol(*args, **kwargs)

Bases: MonitorProtocol, Protocol[_Input, _Target, _Output]

Protocol for a class that trains a model.

model

the model to train.

Type:

drytorch.core.protocols.ModelProtocol[drytorch.core.protocols._Input, drytorch.core.protocols._Output]

learning_schema

contains optimizer settings and scheduling.

Type:

drytorch.core.protocols.LearningProtocol

objective

determines the optimization’s criterion.

Type:

drytorch.core.protocols.LossProtocol[drytorch.core.protocols._Output, drytorch.core.protocols._Target]

validation

class that validates the model.

Type:

drytorch.core.protocols.MonitorProtocol | None

property terminated: bool

If true, this trainer should not be used for training anymore.

train(n_epochs: int) → None

Train the module for the specified number of epochs.

Parameters:

n_epochs (int) – the number of epochs for which to train the module.

Return type:

None

terminate_training(reason: str) → None

Prevent the trainer from continuing the training.

Parameters:

reason (str)

Return type:

None

save_checkpoint() → None

Save model and optimizer state in a checkpoint.

Return type:

None

load_checkpoint(epoch: int = -1) → None

Load model and optimizer state from a checkpoint.

Parameters:

epoch (int)

Return type:

None

update_learning_rate(base_lr: float | None, scheduler: SchedulerProtocol | None) → None

Update the learning rate(s).

It updates the learning rate of each parameter group in the optimizer based on the input learning rate and scheduler.

Parameters:
  • base_lr (float | None) – the initial learning rate.

  • scheduler (SchedulerProtocol | None) – scheduler for the learning rate.

Return type:

None
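
The core of such an update can be sketched as follows. This is an illustration under stated assumptions, not the drytorch implementation: the helper `apply_learning_rate` is hypothetical, parameter groups are modeled as plain dictionaries with an `'lr'` key (the layout used by `torch.optim.Optimizer.param_groups`), the epoch is passed explicitly, and the handling of `None` arguments is not modeled.

```python
from collections.abc import Callable
from typing import Any

# Same call signature as SchedulerProtocol: (base_lr, epoch) -> learning rate.
Scheduler = Callable[[float, int], float]


def apply_learning_rate(
    param_groups: list[dict[str, Any]],
    base_lr: float,
    scheduler: Scheduler,
    epoch: int,
) -> None:
    """Hypothetical helper: write the scheduled rate into every group."""
    scheduled_lr = scheduler(base_lr, epoch)
    for group in param_groups:
        # Only the 'lr' entry is touched; other options are left unchanged.
        group['lr'] = scheduled_lr


def halving(base_lr: float, epoch: int) -> float:
    """Hypothetical schedule that halves the rate at every epoch."""
    return base_lr * 0.5 ** epoch
```

With per-parameter rates (the `dict[str, float]` form of `base_lr` in LearningProtocol), the same loop would look up a separate base value for each group before calling the scheduler.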