drytorch.lib.checkpoints

Module containing classes to save and load the model state and its optimizer state.

Classes

AbstractCheckpoint()

Abstract class that stores and loads the weights of a model implementing ModelProtocol.

CheckpointPathManager(model[, run_dir])

Manage paths for the experiment.

LocalCheckpoint([par_dir])

Manage saving and loading the model state and optimizer state on the local filesystem.

class AbstractCheckpoint

Bases: CheckpointProtocol, ABC

Abstract class that stores and loads the weights of a model implementing ModelProtocol.

Initialize.

property model

The registered model to be saved and loaded.

Raises:

CheckpointNotInitializedError – if no model has been bound.

property optimizer: Optimizer | None

The registered optimizer for the model.

load(epoch: int = -1) -> None

Load the model and optimizer state dictionaries.

Parameters:

epoch (int) – epoch to load.

Raises:
Return type:

None
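The docs above do not say what the default `epoch=-1` selects. A common convention, and only an assumption here, is that `-1` means "the most recent saved epoch". The helper `resolve_epoch` below is hypothetical and not part of drytorch; it sketches that lookup over a list of epoch numbers that have already been saved.

```python
def resolve_epoch(saved_epochs: list[int], epoch: int = -1) -> int:
    """Return the epoch to load (hypothetical helper, not drytorch API).

    Assumes epoch == -1 means "latest saved epoch".
    """
    if not saved_epochs:
        raise FileNotFoundError("no checkpoints have been saved yet")
    if epoch == -1:
        # Assumed convention: -1 selects the most recent checkpoint.
        return max(saved_epochs)
    if epoch not in saved_epochs:
        raise ValueError(f"no checkpoint saved for epoch {epoch}")
    return epoch


print(resolve_epoch([1, 5, 3]))     # 5
print(resolve_epoch([1, 5, 3], 3))  # 3
```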

remove_model() -> None

Remove the registered model.

Return type:

None

bind_model(model: ModelProtocol[Any, Any]) -> None

Bind the model to manage.

Parameters:

model (ModelProtocol[Any, Any]) – the model to manage.

Return type:

None
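Together, `bind_model`, `remove_model` and the `model` property describe a bind-before-use lifecycle: reading `model` while nothing is bound raises `CheckpointNotInitializedError`. The sketch below is hypothetical and not drytorch's implementation; `BindingSketch` is an illustrative name, and the error class is a local stand-in for the one named in these docs.

```python
from typing import Any, Optional


class CheckpointNotInitializedError(Exception):
    """Local stand-in for the error named in the docs (hypothetical)."""


class BindingSketch:
    """Hypothetical sketch of the bind/remove pattern; not drytorch's code."""

    def __init__(self) -> None:
        self._model: Optional[Any] = None

    @property
    def model(self) -> Any:
        # Fail loudly instead of silently returning None when nothing is bound.
        if self._model is None:
            raise CheckpointNotInitializedError("no model has been bound")
        return self._model

    def bind_model(self, model: Any) -> None:
        self._model = model

    def remove_model(self) -> None:
        self._model = None


ckpt = BindingSketch()
ckpt.bind_model("my_model")
print(ckpt.model)  # my_model
ckpt.remove_model()
try:
    ckpt.model
except CheckpointNotInitializedError as err:
    print(err)  # no model has been bound
```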

bind_module(name: str, module: Module) -> None

Bind a module connected to the model.

Parameters:

name (str)

module (Module)

Return type:

None
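`bind_module` attaches extra named modules to the checkpoint. One plausible purpose, assumed here rather than confirmed by these docs, is to save and restore their state together with the model. The sketch uses a hypothetical `ToyModule` that exposes `state_dict()` and `load_state_dict()`, the method names `torch.nn.Module` provides, so it runs without PyTorch; `ModuleRegistry` is likewise an illustrative name.

```python
from typing import Any


class ToyModule:
    """Hypothetical stand-in mimicking torch.nn.Module's state_dict methods."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def state_dict(self) -> dict[str, Any]:
        return {"weight": self.weight}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self.weight = state["weight"]


class ModuleRegistry:
    """Hypothetical sketch: named modules whose states are saved together."""

    def __init__(self) -> None:
        self.modules: dict[str, ToyModule] = {}

    def bind_module(self, name: str, module: ToyModule) -> None:
        self.modules[name] = module

    def collect_state(self) -> dict[str, dict[str, Any]]:
        # One entry per bound module, keyed by the name it was bound under.
        return {name: m.state_dict() for name, m in self.modules.items()}

    def restore_state(self, state: dict[str, dict[str, Any]]) -> None:
        for name, module_state in state.items():
            self.modules[name].load_state_dict(module_state)


registry = ModuleRegistry()
registry.bind_module("encoder", ToyModule(0.5))
registry.bind_module("head", ToyModule(2.0))
print(registry.collect_state())
# {'encoder': {'weight': 0.5}, 'head': {'weight': 2.0}}
```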

bind_optimizer(optimizer: Optimizer) -> None

Bind the optimizer connected to the model.

Parameters:

optimizer (Optimizer)

Return type:

None
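Because the `optimizer` property is typed `Optimizer | None`, a checkpoint may hold a model with no optimizer bound. One reasonable way to handle this, assumed here and not confirmed by the docs, is to leave the optimizer entry out of what gets written. The helper `build_payload` is hypothetical. In real PyTorch code the two inputs would come from `model.state_dict()` and `optimizer.state_dict()`; here they are placeholder dictionaries.

```python
from typing import Any, Optional


def build_payload(
    model_state: dict[str, Any],
    optimizer_state: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Hypothetical helper: assemble what a checkpoint would write to disk.

    The optimizer entry is omitted when no optimizer is bound, mirroring
    the `optimizer: Optimizer | None` property.
    """
    payload: dict[str, Any] = {"model": model_state}
    if optimizer_state is not None:
        payload["optimizer"] = optimizer_state
    return payload


print(build_payload({"w": 1.0}))
# {'model': {'w': 1.0}}
print(build_payload({"w": 1.0}, {"step": 10}))
# {'model': {'w': 1.0}, 'optimizer': {'step': 10}}
```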

save() -> None

Save the model and optimizer state dictionaries.

Return type:

None
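`AbstractCheckpoint` derives from `ABC`, and `LocalCheckpoint` overrides `save`. That suggests the base class leaves the storage backend to subclasses. The sketch below shows that pattern under stated assumptions: `CheckpointSketch` and `InMemoryCheckpoint` are hypothetical names, and the in-memory dictionary backend is swapped in purely for illustration. Which methods drytorch actually declares abstract is not stated here.

```python
from abc import ABC, abstractmethod
from typing import Any


class CheckpointSketch(ABC):
    """Hypothetical base: subclasses decide where states are stored."""

    def __init__(self) -> None:
        self.model_state: dict[str, Any] = {}

    @abstractmethod
    def save(self) -> None:
        """Persist the current state to some backend."""

    @abstractmethod
    def load(self, epoch: int = -1) -> None:
        """Restore the state saved at `epoch`."""


class InMemoryCheckpoint(CheckpointSketch):
    """Hypothetical backend keeping snapshots in a dict keyed by epoch."""

    def __init__(self) -> None:
        super().__init__()
        self.epoch = 0
        self.snapshots: dict[int, dict[str, Any]] = {}

    def save(self) -> None:
        # Store a shallow copy so later key updates do not alter the snapshot.
        self.snapshots[self.epoch] = dict(self.model_state)

    def load(self, epoch: int = -1) -> None:
        # Assumed convention: -1 selects the latest saved epoch.
        key = max(self.snapshots) if epoch == -1 else epoch
        self.model_state = dict(self.snapshots[key])


ckpt = InMemoryCheckpoint()
ckpt.model_state, ckpt.epoch = {"w": 1.0}, 1
ckpt.save()
ckpt.model_state, ckpt.epoch = {"w": 2.0}, 2
ckpt.save()
ckpt.load(1)
print(ckpt.model_state)  # {'w': 1.0}
```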

class LocalCheckpoint(par_dir: Path | None = None)

Bases: AbstractCheckpoint

Manage saving and loading the model state and optimizer state on the local filesystem.

Initialize.

Parameters:

par_dir (Path | None) – parent directory for experiment data.

property paths: CheckpointPathManager

Path manager for directories and checkpoints.
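The docs do not describe the directory layout that `CheckpointPathManager` produces. The sketch below is purely hypothetical and shows one way a path manager could derive checkpoint file paths from `par_dir`, a model name and an epoch number. The function `checkpoint_file` and the `<model>/checkpoints/epoch_<N>/<kind>.pt` layout are both illustrative assumptions. `PurePosixPath` is used so the printed result does not depend on the operating system.

```python
from pathlib import PurePosixPath


def checkpoint_file(
    par_dir: PurePosixPath, model_name: str, epoch: int, kind: str
) -> PurePosixPath:
    """Hypothetical layout: <par_dir>/<model_name>/checkpoints/epoch_<N>/<kind>.pt"""
    return par_dir / model_name / "checkpoints" / f"epoch_{epoch}" / f"{kind}.pt"


root = PurePosixPath("experiments/run_1")
print(checkpoint_file(root, "resnet", 3, "model"))
# experiments/run_1/resnet/checkpoints/epoch_3/model.pt
```

Centralizing path construction in one object means `save` and `load` cannot disagree about where a given epoch lives.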

save() -> None

Save the model and optimizer state dictionaries.

Return type:

None
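A local save and load round-trip can be sketched with the standard library alone. Here `pickle` stands in for `torch.save` and `torch.load`, which is what PyTorch code would normally use for state dictionaries. The functions `save_local` and `load_local`, and the choice of one file per state dictionary, are hypothetical and not taken from drytorch. As with `torch.load`, only unpickle files you trust.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


def save_local(epoch_dir: Path, model_state: dict[str, Any],
               optimizer_state: dict[str, Any]) -> None:
    """Hypothetical sketch: write each state dictionary to its own file."""
    epoch_dir.mkdir(parents=True, exist_ok=True)  # create missing parent dirs
    with (epoch_dir / "model.pkl").open("wb") as f:
        pickle.dump(model_state, f)
    with (epoch_dir / "optimizer.pkl").open("wb") as f:
        pickle.dump(optimizer_state, f)


def load_local(epoch_dir: Path) -> tuple[dict[str, Any], dict[str, Any]]:
    """Read back the two files written by save_local."""
    with (epoch_dir / "model.pkl").open("rb") as f:
        model_state = pickle.load(f)
    with (epoch_dir / "optimizer.pkl").open("rb") as f:
        optimizer_state = pickle.load(f)
    return model_state, optimizer_state


with tempfile.TemporaryDirectory() as tmp:
    epoch_dir = Path(tmp) / "checkpoints" / "epoch_1"
    save_local(epoch_dir, {"w": 1.0}, {"step": 10})
    print(load_local(epoch_dir))  # ({'w': 1.0}, {'step': 10})
```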