"""Module containing classes to save the model state and its optimizer state."""
import abc
import codecs
import pathlib
import warnings
from pathlib import Path
from typing import Any, ClassVar, Final
import numpy as np
import torch
from torch import distributed as dist
from typing_extensions import override
from drytorch.core import exceptions, experimenting, log_events
from drytorch.core import protocols as p
# Names exported by a star-import of this module.
__all__ = [
'AbstractCheckpoint',
'LocalCheckpoint',
]
SAFE_GLOBALS: list[Any] = [
np.bool_,
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.float16,
np.float32,
np.float64,
np.complex64,
np.complex128,
np.dtype,
codecs.encode,
]
try:
from numpy._core.multiarray import scalar # type: ignore # pyright: ignore
except ImportError:
pass
else:
SAFE_GLOBALS.append(scalar)
SAFE_GLOBALS.extend([getattr(np.dtypes, name) for name in np.dtypes.__all__])
torch.serialization.add_safe_globals(SAFE_GLOBALS)
class CheckpointPathManager:
    """Resolve the directories and file paths used to checkpoint a model.

    Class Attributes:
        folder_name: name of the folder where the checkpoints are stored.
    """

    folder_name: ClassVar[str] = 'checkpoints'
    _model: p.ModelProtocol[Any, Any]
    _run_dir: Path | None

    def __init__(
        self,
        model: p.ModelProtocol[Any, Any],
        run_dir: pathlib.Path | None = None,
    ) -> None:
        """Initialize.

        Args:
            model: the model whose paths are to be managed.
            run_dir: the directory for experiment data. When None, the
                directory is derived from the current experiment each time
                ``run_dir`` is read.
        """
        self._run_dir = run_dir
        self._model: Final = model

    @property
    def run_dir(self) -> pathlib.Path:
        """Parent directory for the checkpoints.

        Raises:
            AccessOutsideScopeError: if no directory was given at
                construction and there is no active experiment.
        """
        # An explicitly provided directory always wins.
        if self._run_dir is not None:
            return self._run_dir

        try:
            experiment = experimenting.Experiment[Any].get_current()
        except exceptions.NoActiveExperimentError as naee:
            raise exceptions.AccessOutsideScopeError from naee

        base_dir = experiment.par_dir / self.folder_name / experiment.name
        if '@' not in experiment.run.id:
            return base_dir / experiment.run.id

        # A run id containing '@' is split into two nested folders.
        day, time = experiment.run.id.split('@')
        return base_dir / day / time

    @property
    def model_dir(self) -> pathlib.Path:
        """Directory for the model: the run directory plus the model name."""
        return self.run_dir / self._model.name

    @property
    def epoch_dir(self) -> pathlib.Path:
        """Directory for a checkpoint at the model's current epoch."""
        return self.model_dir / f'epoch_{self._model.epoch}'

    def get_model_state_path(self, module_name: str) -> pathlib.Path:
        """Get the path of the file with the state of the named module.

        Args:
            module_name: name under which the module is stored.
        """
        return self.epoch_dir / f'{module_name}_state.pt'

    def get_optimizer_state_path(self) -> pathlib.Path:
        """Get the path of the file with the optimizer state."""
        return self.epoch_dir / 'optimizer_state.pt'
class AbstractCheckpoint(p.CheckpointProtocol, abc.ABC):
    """Abstract base that saves and loads the state of a bound model.

    A model, further named modules, and optionally an optimizer are bound to
    the checkpoint. Subclasses decide where and how their state dictionaries
    are stored by implementing the abstract hooks.
    """

    _model: p.ModelProtocol[Any, Any] | None
    _optimizer: torch.optim.Optimizer | None
    _modules: dict[str, Any]

    def __init__(self) -> None:
        """Initialize with no model, optimizer, or module bound."""
        self._model = None
        self._optimizer = None
        self._modules = {}

    @property
    def model(self) -> p.ModelProtocol[Any, Any]:
        """The registered model to be saved and loaded.

        Raises:
            CheckpointNotInitializedError: if no model has been bound.
        """
        if self._model is None:
            raise exceptions.CheckpointNotInitializedError()
        return self._model

    @property
    def optimizer(self) -> torch.optim.Optimizer | None:
        """The registered optimizer for the model, or None if not bound."""
        return self._optimizer

    def load(self, epoch: int = -1) -> None:
        """Load the model and optimizer state dictionaries.

        The epoch of the bound model is updated to the one being loaded.

        Args:
            epoch: epoch to load. Defaults to -1, which selects the last
                saved epoch.

        Raises:
            ValueError: if epoch is smaller than -1.
            ModelNotFoundError: if the model location does not exist.
            EpochNotFoundError: if the epoch location does not exist.
        """
        if self._is_distributed():
            # Presumably waits for a pending save by rank 0 to complete
            # before any process reads the files.
            device_idx = self.model.device.index
            if device_idx is None:
                dist.barrier()
            else:
                dist.barrier(device_ids=[device_idx])

        self._update_epoch(epoch)
        self._check_location(epoch)
        log_events.LoadModelEvent(
            model_name=self.model.name,
            definition=self._get_definition(),
            location=self._get_location(),
            epoch=self.model.epoch,
        )
        for name, module in self._modules.items():
            self._load_module(name, module, epoch)

        self._load_optimizer(epoch)

    def remove_model(self) -> None:
        """Remove the registered model, its modules, and the optimizer."""
        self._model = None
        self._optimizer = None
        self._modules.clear()

    def bind_model(self, model: p.ModelProtocol[Any, Any]) -> None:
        """Bind the model to manage.

        The model's own module is also bound under the name 'model'.

        Args:
            model: the model to save and load.
        """
        self._model = model
        self.bind_module('model', model.module)

    def bind_module(self, name: str, module: torch.nn.Module) -> None:
        """Bind a module connected to the model.

        Args:
            name: name identifying the module; an existing binding with the
                same name is replaced.
            module: the module whose state is saved and loaded.
        """
        self._modules[name] = module

    def bind_optimizer(self, optimizer: torch.optim.Optimizer) -> None:
        """Bind the optimizer connected to the model.

        Args:
            optimizer: the optimizer whose state is saved and loaded.
        """
        self._optimizer = optimizer

    def save(self) -> None:
        """Save the model and optimizer state dictionaries.

        When a distributed process group is active, every process creates
        the save event but only the process with rank 0 writes the states.
        """
        log_events.SaveModelEvent(
            model_name=self.model.name,
            definition=self._get_definition(),
            location=self._get_location(),
            epoch=self.model.epoch,
        )
        if self._is_distributed() and dist.get_rank():
            return

        for name, module in self._modules.items():
            self._save_module(name, module)

        self._save_optimizer()

    @abc.abstractmethod
    def _check_location(self, epoch: int) -> None: ...

    def _get_definition(self) -> str:
        # Without an optimizer only the model state is stored.
        return 'state' if self.optimizer is None else 'checkpoint'

    @abc.abstractmethod
    def _get_last_saved_epoch(self) -> int: ...

    @abc.abstractmethod
    def _get_location(self) -> str: ...

    @staticmethod
    def _is_distributed() -> bool:
        """Tell whether a torch.distributed process group is active."""
        # is_available must be called: the bare function object is always
        # truthy, and is_initialized is not guaranteed to be exposed when
        # the distributed package is unavailable.
        return dist.is_available() and dist.is_initialized()

    @abc.abstractmethod
    def _load_module(
        self, name: str, module: torch.nn.Module, epoch: int
    ) -> None: ...

    @abc.abstractmethod
    def _load_optimizer(self, epoch: int) -> None: ...

    @abc.abstractmethod
    def _save_module(self, name: str, module: torch.nn.Module) -> None: ...

    @abc.abstractmethod
    def _save_optimizer(self) -> None: ...

    def _update_epoch(self, epoch: int) -> None:
        """Set the model epoch, resolving -1 to the last saved epoch."""
        if epoch < -1:
            raise ValueError('Epoch must be larger than -1.')

        epoch = epoch if epoch >= 0 else self._get_last_saved_epoch()
        self.model.epoch = epoch
class LocalCheckpoint(AbstractCheckpoint):
    """Manage locally saving and loading the model state and optimizer.

    States are written with ``torch.save`` to one folder per epoch and read
    back with ``torch.load(..., weights_only=True)``.
    """

    def __init__(self, par_dir: pathlib.Path | None = None) -> None:
        """Initialize.

        Args:
            par_dir: parent directory for experiment data. It is handed to
                the path manager as the run directory; when None, that
                directory is derived from the current experiment.
        """
        super().__init__()
        self._par_dir = par_dir

    @property
    def paths(self) -> CheckpointPathManager:
        """Path manager for directories and checkpoints.

        A new manager is built from the currently bound model on each access.
        """
        return CheckpointPathManager(self.model, self._par_dir)

    @override
    def save(self) -> None:
        """Create the epoch directory if needed, then save the states."""
        self.paths.epoch_dir.mkdir(exist_ok=True, parents=True)
        super().save()

    @override
    def _check_location(self, epoch: int) -> None:
        paths = self.paths
        if not paths.model_dir.exists():
            raise exceptions.ModelNotFoundError(paths.run_dir)

        if not paths.epoch_dir.exists():
            raise exceptions.EpochNotFoundError(epoch, paths.model_dir)

    @override
    def _load_module(
        self, name: str, module: torch.nn.Module, epoch: int = -1
    ) -> None:
        # The file is selected through the model's current epoch; the epoch
        # argument is only part of the interface.
        state_dict = torch.load(
            self.paths.get_model_state_path(name),
            map_location=self.model.device,
            weights_only=True,
        )
        module.load_state_dict(state_dict)

    @override
    def _load_optimizer(self, epoch: int = -1) -> None:
        # Same existence checks and exceptions as for the modules.
        self._check_location(epoch)
        if self.optimizer is None:
            return

        try:
            self.optimizer.load_state_dict(
                torch.load(
                    self.paths.get_optimizer_state_path(),
                    map_location=self.model.device,
                    weights_only=True,
                ),
            )
        except ValueError as ve:
            # A state that does not fit the optimizer is reported as a
            # warning instead of aborting the load of the model.
            warnings.warn(
                exceptions.OptimizerNotLoadedWarning(ve), stacklevel=1
            )

    @override
    def _save_module(self, name: str, module: torch.nn.Module) -> None:
        torch.save(module.state_dict(), self.paths.get_model_state_path(name))

    @override
    def _save_optimizer(self) -> None:
        if self.optimizer is None:
            return

        torch.save(
            self.optimizer.state_dict(),
            self.paths.get_optimizer_state_path(),
        )

    @override
    def _get_last_saved_epoch(self) -> int:
        model_directory = self.paths.model_dir
        epoch_dirs: list[pathlib.Path] = []
        if model_directory.exists():
            # Folders without an integer suffix cannot be epoch folders;
            # skipping them avoids a ValueError when parsing their name.
            epoch_dirs = [
                entry
                for entry in model_directory.iterdir()
                if entry.is_dir() and self._is_epoch_dir(entry)
            ]

        if not epoch_dirs:
            raise exceptions.ModelNotFoundError(model_directory)

        # "Last" means most recently written, not the highest epoch number.
        last_epoch_dir = max(epoch_dirs, key=self._creation_time)
        return self._get_epoch(last_epoch_dir)

    @override
    def _get_location(self) -> str:
        return str(self.paths.epoch_dir)

    @staticmethod
    def _creation_time(directory: pathlib.Path) -> float:
        """Return the latest ``st_ctime`` among the entries of a directory.

        An empty directory yields 0.0. Note that ``st_ctime`` is the
        creation time on Windows but the time of the last metadata change
        on Unix.
        """
        return max(
            (entry.stat().st_ctime for entry in directory.iterdir()),
            default=0.0,
        )

    @staticmethod
    def _get_epoch(directory: pathlib.Path) -> int:
        """Parse the integer following the last underscore of the name."""
        return int(directory.stem.rsplit('_', 1)[-1])

    @classmethod
    def _is_epoch_dir(cls, directory: pathlib.Path) -> bool:
        """Tell whether an epoch number can be parsed from the name."""
        try:
            cls._get_epoch(directory)
        except ValueError:
            return False

        return True