# Source code for drytorch.lib.schedulers

"""Module containing schedulers for the learning rates."""

from __future__ import annotations

import abc
import dataclasses

from collections.abc import Callable, Iterable

import numpy as np

from typing_extensions import override

from drytorch.core import protocols as p


__all__ = [
    'AbstractScheduler',
    'ConstantScheduler',
    'CosineScheduler',
    'ExponentialScheduler',
    'PolynomialScheduler',
    'StepScheduler',
    'rescale',
    'restart',
    'warmup',
]


class AbstractScheduler(p.SchedulerProtocol, abc.ABC):
    """Base class for learning rate schedulers.

    Calling an instance validates the inputs and then delegates to the
    ``_compute`` method that concrete subclasses must implement.
    """

    def __call__(self, base_lr: float, epoch: int) -> float:
        """Return the learning rate scheduled for the given epoch.

        Args:
            base_lr: initial learning rate.
            epoch: the current epoch.

        Returns:
            scheduled value for the learning rate.

        Raises:
            ValueError: if base_lr or epoch is negative.
        """
        has_negative_input = base_lr < 0 or epoch < 0
        if has_negative_input:
            raise ValueError('Base learning rate and epoch must be positive.')

        return self._compute(base_lr, epoch)

    def bind(
        self,
        f: Callable[[AbstractScheduler], AbstractScheduler],
        /,
    ) -> AbstractScheduler:
        """Apply a transformation to this scheduler.

        Args:
            f: a function that receives this scheduler and returns a new one.

        Returns:
            the scheduler produced by the transformation.
        """
        transformed = f(self)
        return transformed

    @abc.abstractmethod
    def _compute(self, start_value: float, epoch: int) -> float:
        """Compute the scheduled value (implemented by subclasses).

        Args:
            start_value: value when epoch is 0.
            epoch: variable of the function.

        Returns:
            the value for learning rate to use.
        """
# frozen=True is required so that an instance can be used as a default value.
@dataclasses.dataclass(frozen=True)
class ConstantScheduler(AbstractScheduler):
    """Scheduler that keeps the learning rate at its starting value."""

    @override
    def _compute(self, start_value: float, epoch: int) -> float:
        # The epoch is ignored: the value is the same at every step.
        return start_value
@dataclasses.dataclass(frozen=True)
class PolynomialScheduler(AbstractScheduler):
    """Polynomial learning rate scheduler: f(x) = C0 + C1(1 - x/C2)^C3.

    C0, C1, C2, C3 are defined so that:
        - f(x) = base_value when epoch = 0,
        - f(x) = min value when epoch is C2 = number of decay steps and,
        - f(x) is a polynomial of degree C3.
    After the number of decay steps, returns min value.

    Attributes:
        max_epochs: number of epochs over which the value decays (C2).
        power: polynomial power (C3).
        min_decay: minimum fraction of the initial learning rate.
    """

    max_epochs: int = 1000
    power: float = 1.0
    min_decay: float = 0.0

    def __post_init__(self):
        """Validate the parameters.

        Raises:
            ValueError: if max_epochs is not positive, power is negative, or
                min_decay is outside [0, 1].
        """
        if self.max_epochs <= 0:
            raise ValueError('max_epochs must be positive.')

        if self.power < 0:
            raise ValueError('power must be non-negative.')

        if not 0 <= self.min_decay <= 1:
            raise ValueError('min_decay must be between 0 and 1.')

    @override
    def _compute(self, start_value: float, epoch: int) -> float:
        """Interpolate polynomially between start_value and the minimum.

        Args:
            start_value: value when epoch is 0.
            epoch: the current epoch.

        Returns:
            the scheduled value, equal to min_decay * start_value once epoch
            reaches max_epochs.
        """
        min_value = self.min_decay * start_value
        if epoch >= self.max_epochs:
            return min_value

        decay_factor = (1 - epoch / self.max_epochs) ** self.power
        # Scale by start_value: the result must be a learning rate, not a
        # bare fraction, so that f(0) == start_value.
        return min_value + (start_value - min_value) * decay_factor
@dataclasses.dataclass(frozen=True)
class ExponentialScheduler(AbstractScheduler):
    """Schedule exponential decay: f(x) = C0 + C1(C2^x).

    C0, C1 and C2 are defined so that:
        - f(x) = base_value when epoch = 0,
        - f(x) tends to the min value as the epoch goes to infinity
          (for C2 < 1) and,
        - f(x) is an exponential function with decay factor C2.

    Attributes:
        exp_decay: exponential decay factor C2, in the interval (0, 1].
        min_decay: proportion of base learning rate for the minimum C0.
    """

    exp_decay: float = 0.975
    min_decay: float = 0.00

    def __post_init__(self):
        """Validate the parameters.

        Raises:
            ValueError: if exp_decay is outside (0, 1] or min_decay is
                outside [0, 1].
        """
        if not 0 < self.exp_decay <= 1:
            # The check accepts exp_decay == 1 (no decay), so say so.
            raise ValueError('exp_decay must be positive and at most 1.')

        if not 0 <= self.min_decay <= 1:
            raise ValueError('min_decay must be between 0 and 1.')

    @override
    def _compute(self, start_value: float, epoch: int) -> float:
        """Decay the part of start_value above the minimum exponentially.

        Args:
            start_value: value when epoch is 0.
            epoch: the current epoch.

        Returns:
            the scheduled value for the learning rate.
        """
        min_value = self.min_decay * start_value
        return (start_value - min_value) * self.exp_decay**epoch + min_value
@dataclasses.dataclass(frozen=True)
class CosineScheduler(AbstractScheduler):
    """Cosine decay schedule: f(x) = C0 + C1(1 + cos(πx/C2)).

    The constants are chosen so that:
        - f(x) equals the base value at epoch 0 and,
        - f(x) equals the minimum value at epoch C2 (the decay steps).
    Past the decay steps the minimum value is returned.

    Attributes:
        decay_steps: number of steps (epochs) to reach maximum decay.
        min_decay: fraction of base_value for the minimum value.
    """

    decay_steps: int = 250
    min_decay: float = 0.01

    def __post_init__(self):
        """Check that the parameters are within their valid ranges."""
        if self.decay_steps <= 0:
            raise ValueError('decay_steps must be positive.')

        min_decay_is_fraction = 0 <= self.min_decay <= 1
        if not min_decay_is_fraction:
            raise ValueError('min_decay must be between 0 and 1.')

    @override
    def _compute(self, start_value: float, epoch: int) -> float:
        floor = self.min_decay * start_value
        if epoch > self.decay_steps:
            return floor

        # Half a cosine period: goes from 1 (epoch 0) to -1 (decay_steps).
        cosine = np.cos(np.pi * epoch / self.decay_steps)
        amplitude = start_value - floor
        return floor + amplitude * (1 + cosine) / 2
@dataclasses.dataclass(frozen=True)
class RescaleScheduler(AbstractScheduler):
    """Multiply the output of another scheduler by a constant factor.

    Attributes:
        base_scheduler: the scheduler to call.
        factor: factor that rescales the value.
    """

    base_scheduler: p.SchedulerProtocol
    factor: float

    def __post_init__(self):
        """Check that the factor is strictly positive."""
        if self.factor <= 0:
            raise ValueError('factor must be positive.')

    @override
    def _compute(self, start_value: float, epoch: int) -> float:
        unscaled = self.base_scheduler(start_value, epoch)
        return self.factor * unscaled


@dataclasses.dataclass(frozen=True)
class RestartScheduler(AbstractScheduler):
    """Restart a wrapped scheduler periodically.

    Attributes:
        base_scheduler: the scheduler to restart.
        restart_interval: the number of epochs between restarts.
        restart_fraction: fraction of the base value to use as the base value
            when restarting.
        max_restart: Maximum number of restarts before deactivating. Default
            never deactivates.
    """

    base_scheduler: p.SchedulerProtocol
    restart_interval: int
    restart_fraction: float = 1.0
    max_restart: int | None = None

    def __post_init__(self):
        """Check that the parameters are strictly positive."""
        if self.restart_interval <= 0:
            raise ValueError('restart_interval must be positive.')

        if self.restart_fraction <= 0:
            raise ValueError('restart_fraction must be positive.')

        if self.max_restart is not None and self.max_restart <= 0:
            raise ValueError('max_restart must be positive.')

    @override
    def _compute(self, start_value: float, epoch: int) -> float:
        interval = self.restart_interval
        if epoch >= interval:
            cycle, offset = divmod(epoch, interval)
            restarts_active = (
                self.max_restart is None or cycle <= self.max_restart
            )
            if restarts_active and offset:
                # Inside a restarted cycle: rescale the base value and
                # evaluate the wrapped scheduler at the in-cycle position.
                start_value *= self.restart_fraction
                epoch = offset
            elif restarts_active:
                # Exactly on a cycle boundary: evaluate at the end of a
                # cycle, with the unscaled base value.
                epoch = interval

        return self.base_scheduler(start_value, epoch)


@dataclasses.dataclass(frozen=True)
class WarmupScheduler(AbstractScheduler):
    """Prepend a linear warmup phase to another scheduler.

    During warmup, the learning rate increases linearly from 0 to base_lr.
    After warmup, delegates to the wrapped scheduler with adjusted epochs.

    Attributes:
        base_scheduler: the base scheduler to wrap with warmup.
        warmup_steps: number of steps (epochs) for the linear warmup phase.
    """

    base_scheduler: p.SchedulerProtocol
    warmup_steps: int = 10

    def __post_init__(self):
        """Check that the number of warmup steps is not negative."""
        if self.warmup_steps < 0:
            raise ValueError('warmup_steps must be non-negative.')

    @override
    def _compute(self, start_value: float, epoch: int) -> float:
        steps = self.warmup_steps
        if epoch < steps:
            progress = epoch / steps
            return start_value * progress

        # Shift the epoch so the wrapped scheduler starts from zero.
        return self.base_scheduler(start_value, epoch - steps)

    @override
    def __repr__(self) -> str:
        inner = self.base_scheduler.__repr__()
        return f'{inner} with {self.warmup_steps} warm-up steps'
@dataclasses.dataclass(frozen=True)
class StepScheduler(AbstractScheduler):
    """Step-wise learning rate scheduler.

    Reduces learning rate by a factor at specified milestones.

    Attributes:
        milestones: iterable of epochs at which to reduce the learning rate.
            Lists and tuples are stored as given; any other iterable is
            copied into a list so that it can be iterated more than once.
        gamma: factor by which to reduce learning rate.
    """

    milestones: Iterable[int] = dataclasses.field(
        default_factory=lambda: [200]
    )
    gamma: float = 0.1

    def __post_init__(self):
        """Validate the parameters.

        Raises:
            ValueError: if a milestone is not positive, the milestones are
                not in ascending order, or gamma is outside (0, 1].
        """
        milestones = self.milestones
        if not isinstance(milestones, (list, tuple)):
            # A one-shot iterable would be exhausted by the validation below
            # and leave nothing for _compute. The dataclass is frozen, hence
            # object.__setattr__.
            milestones = list(milestones)
            object.__setattr__(self, 'milestones', milestones)

        if not all(m > 0 for m in milestones):
            raise ValueError('All milestones must be positive.')

        # Compare list with list: sorted() always returns a list, so a tuple
        # compared directly would never be equal to it.
        if list(milestones) != sorted(milestones):
            raise ValueError('Milestones must be in ascending order.')

        if not 0 < self.gamma <= 1:
            raise ValueError('gamma must be between 0 and 1 (exclusive of 0).')

    @override
    def _compute(self, start_value: float, epoch: int) -> float:
        """Apply one gamma reduction per milestone already reached.

        Args:
            start_value: value when epoch is 0.
            epoch: the current epoch.

        Returns:
            start_value multiplied by gamma once for every milestone that is
            less than or equal to epoch.
        """
        count = sum(1 for milestone in self.milestones if epoch >= milestone)
        return start_value * (self.gamma**count)
# Decorator functions for functional composition
def rescale(factor: float) -> Callable[[AbstractScheduler], AbstractScheduler]:
    """Build a transformation that rescales a scheduler's output.

    Args:
        factor: factor that rescales the value.

    Returns:
        A decorator that wraps a scheduler with scaling.
    """

    def _apply_rescale(scheduler: AbstractScheduler) -> AbstractScheduler:
        rescaled = RescaleScheduler(scheduler, factor)
        return rescaled

    return _apply_rescale
def restart(
    restart_interval: int,
    restart_fraction: float = 1.0,
    max_restart: int | None = None,
) -> Callable[[AbstractScheduler], AbstractScheduler]:
    """Create a restart transformation.

    Args:
        restart_interval: number of epochs between restarts.
        restart_fraction: fraction to use when restarting.
        max_restart: maximum number of restarts before deactivating. The
            default, None, never deactivates.

    Returns:
        A decorator that wraps a scheduler with restarts.
    """

    def _decorator(scheduler: AbstractScheduler) -> AbstractScheduler:
        return RestartScheduler(
            scheduler,
            restart_interval,
            restart_fraction,
            max_restart,
        )

    return _decorator
def warmup(
    warmup_steps: int = 10,
) -> Callable[[AbstractScheduler], AbstractScheduler]:
    """Build a transformation that adds a warmup phase to a scheduler.

    Args:
        warmup_steps: number of warmup steps.

    Returns:
        A decorator that wraps a scheduler with warmup.
    """

    def _apply_warmup(scheduler: AbstractScheduler) -> AbstractScheduler:
        warmed_up = WarmupScheduler(scheduler, warmup_steps)
        return warmed_up

    return _apply_warmup