Source code for drytorch.lib.train

"""Module containing classes for training a model."""

import warnings

from typing import Final, Self, TypeVar

import torch

from typing_extensions import override

from drytorch.core import exceptions, log_events
from drytorch.core import protocols as p
from drytorch.lib import evaluations, hooks, models, runners


__all__ = [
    'Trainer',
]


Input = TypeVar('Input', bound=p.InputType)
Target = TypeVar('Target', bound=p.TargetType)
Output = TypeVar('Output', bound=p.OutputType)


[docs] class Trainer( runners.ModelRunnerWithLogs[ Input, Target, Output, p.LossProtocol[Output, Target] ], p.TrainerProtocol[Input, Target, Output], ): """Implement the standard Pytorch training loop. Attributes: model: the model to train. loader: provides inputs and targets in batches. objective: determines the optimization's criterion. learning_schema: contains optimizer settings and scheduling. validation: class that validates the model, """ def __init__( self, model: p.ModelProtocol[Input, Output], name: str = '', *, loader: p.LoaderProtocol[tuple[Input, Target]], loss: p.LossProtocol[Output, Target], learning_schema: p.LearningProtocol, ) -> None: """Initialize. Args: model: the model containing the weights to evaluate. name: the base name for the object for logging purposes. Defaults to class name plus eventual counter. loader: provides inputs and targets in batches. loss: determines the optimization's criterion. learning_schema: contains optimizer settings and scheduling. """ super().__init__(model, loader=loader, objective=loss, name=name) self.learning_schema: Final = learning_schema self.validation: p.MonitorProtocol | None = None self._model_optimizer: Final = models.ModelOptimizer( model, learning_schema ) self.pre_epoch_hooks: Final = hooks.HookRegistry[ Trainer[Input, Target, Output] ]() self.post_epoch_hooks: Final = hooks.HookRegistry[ Trainer[Input, Target, Output] ]() self._terminated = False return @property @override def terminated(self) -> bool: return self._terminated
[docs] @override def __call__(self, store_outputs: bool = False) -> None: """Train the module for one epoch. Args: store_outputs: whether to store model outputs. """ if self.terminated: warnings.warn(exceptions.TerminatedTrainingWarning(), stacklevel=1) return self.model.module.train() self.model.increment_epoch() self._model_optimizer.update_learning_rate() try: super().__call__() except exceptions.ConvergenceError as ce: self.terminate_training(reason=str(ce)) raise ce return
[docs] def add_validation( self, val_loader: p.LoaderProtocol[tuple[Input, Target]], interval: int = 1, ) -> None: """Add a loader for validation with the same metrics as for training. If different validation loaders are added, they will all be performed, but only the last will be stored as the instance validation. Args: val_loader: the loader for validation. interval: the frequency of validation. Raises: ValueError: if the interval is not strictly positive. """ validation = evaluations.Validation( self.model, loader=val_loader, metric=self.objective ) val_hook = hooks.StaticHook(validation) if interval < 1: raise ValueError(f'Interval must larger than 0. Got {interval}.') if interval > 1: val_hook.bind(hooks.call_every(interval)) self.post_epoch_hooks.register(val_hook) self.validation = validation return
[docs] @override def load_checkpoint(self, epoch: int = -1) -> None: """Load model and optimizer state from a checkpoint. Args: epoch: the epoch from which to load the checkpoint. Defaults to the last saved epoch. """ self._model_optimizer.load(epoch=epoch) return
[docs] @override def save_checkpoint(self) -> None: self._model_optimizer.save()
[docs] @override def terminate_training(self, reason: str) -> None: self._terminated = True log_events.TerminatedTrainingEvent( source_name=self.name, model_name=self.model.name, epoch=self.model.epoch, reason=reason, ) return
[docs] @override def train(self, n_epochs: int) -> None: if self.terminated: warnings.warn(exceptions.TerminatedTrainingWarning(), stacklevel=1) return final_epoch = self.model.epoch + n_epochs log_events.StartTrainingEvent( source_name=self.name, model_name=self.model.name, start_epoch=self.model.epoch, end_epoch=final_epoch, ) for _ in range(n_epochs): log_events.StartEpochEvent( source_name=self.name, model_name=self.model.name, epoch=self.model.epoch + 1, end_epoch=final_epoch, ) self.pre_epoch_hooks.execute(self) self() self.post_epoch_hooks.execute(self) log_events.EndEpochEvent( source_name=self.name, model_name=self.model.name, epoch=self.model.epoch, ) if self.terminated: break log_events.EndTrainingEvent(self.name) return
[docs] def train_until(self: Self, epoch: int) -> None: """Train the module until the specified epoch. Args: epoch: the final epoch in the training. """ remaining_epochs = epoch - self.model.epoch if remaining_epochs > 0: self.train(remaining_epochs) if remaining_epochs < 0: warnings.warn( exceptions.PastEpochWarning(epoch, self.model.epoch), stacklevel=1, ) return
[docs] @override def update_learning_rate( self, base_lr: float | dict[str, float] | None = None, scheduler: p.SchedulerProtocol | None = None, ) -> None: """Update the learning rate(s). It updates the learning rates for each parameter's group in the optimizer based on input learning rate(s) and scheduler. Args: base_lr: initial learning rates for named parameters or global value. Default keeps the original learning rates. scheduler: scheduler for the learning rates. Default keeps the original scheduler. """ scheduler_name = None if scheduler is None else repr(scheduler) log_events.LearningRateEvent( model_name=self.model.name, source_name=self.name, epoch=self.model.epoch, base_lr=base_lr, scheduler_name=scheduler_name, ) self._model_optimizer.update_learning_rate(base_lr, scheduler) return
@override def _run_backward(self, outputs: Output, targets: Target) -> None: # replace super call loss_value = self.objective.forward(outputs, targets) try: if torch.isinf(loss_value) or torch.isnan(loss_value): raise exceptions.ConvergenceError(loss_value.item()) except RuntimeError as re: if loss_value.numel() != 1: raise exceptions.LossNotScalarError(loss_value.shape) from re raise re self._model_optimizer.optimize(loss_value) self.model.update_parameters() return