Source code for drytorch.lib.gradient_ops

"""Module containing gradient operations."""

import abc
import copy
import math

from collections import defaultdict
from collections.abc import Callable, Iterable
from typing import ClassVar, Final, TypeAlias

import torch

from typing_extensions import override

from drytorch.core import protocols as p
from drytorch.core.protocols import GradientOpProtocol


# Names exported by `from drytorch.lib.gradient_ops import *`.
__all__ = [
    'ClippingCriterion',
    'EMACriterion',
    'GradNormClipper',
    'GradParamNormalizer',
    'GradValueClipper',
    'GradZScoreNormalizer',
    'HistClipper',
    'NoOp',
    'ParamHistClipper',
    'StatsCollector',
    'ZStatCriterion',
    'max_clipping',
    'mean_clipping',
    'reciprocal_clipping',
]


# Signature of a clipping function: it is called with the current statistic
# (Z-score or norm ratio) followed by the threshold, and returns the statistic
# that the criterion uses to compute the clipped norm.
ClipFunction: TypeAlias = Callable[[float, float], float]


def _validate_threshold(threshold: float) -> None:
    if threshold <= 0:
        raise ValueError('Gradient threshold must be positive.')


class NoOp(GradientOpProtocol):
    """Gradient operation that leaves every gradient untouched."""

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Accept the parameters without reading or modifying them."""
        return None
class GradParamNormalizer(p.GradientOpProtocol):
    """Rescale the gradient of each parameter to (almost) unit L2 norm."""

    # added to the norm so that a vanishing gradient does not divide by zero
    _eps = 1e-8

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Divide every available gradient by its own L2 norm, in-place.

        Args:
            params: parameters whose gradients are normalized; parameters
                without a gradient are skipped.
        """
        gradients = (param.grad for param in params)
        for gradient in gradients:
            if gradient is not None:
                gradient.div_(gradient.norm(2) + self._eps)
class GradZScoreNormalizer(p.GradientOpProtocol):
    """Gradient normalizing strategy using Z-score normalization.

    Each gradient tensor is centered on its own mean and divided by its own
    sample (unbiased) standard deviation.
    """

    # added to the standard deviation to avoid dividing by zero
    _eps = 1e-8

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Normalize gradients using Z-score in-place.

        Args:
            params: parameters whose gradients are normalized. Parameters
                without a gradient are skipped. Gradients with fewer than
                two elements are left unchanged because their sample
                standard deviation is undefined.
        """
        for param in params:
            grad = param.grad
            if grad is None:
                continue

            # The unbiased estimator divides by (n - 1): for a one-element
            # gradient it evaluates to NaN, and the in-place division below
            # would replace the gradient with NaN.
            if grad.numel() < 2:
                continue

            mean = grad.mean()
            std = grad.std(unbiased=True)
            grad.sub_(mean).div_(std + self._eps)
class ClipOperation(p.GradientOpProtocol, abc.ABC):
    """Abstract parent class of the gradient operations defined below."""

    @abc.abstractmethod
    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Run the gradient operation on the given parameters.

        Args:
            params: parameters whose gradients are processed.
        """
        ...
class GradNormClipper(ClipOperation):
    """Clip gradients so that their norm does not exceed a threshold.

    Attributes:
        threshold: upper bound for the norm of the clipped gradients.
    """

    threshold: float

    def __init__(self, threshold: float = 1) -> None:
        """Store the validated threshold.

        Args:
            threshold: upper bound for the norm of the clipped gradients.

        Raises:
            ValueError: if the threshold is not positive.
        """
        super().__init__()
        _validate_threshold(threshold)
        self.threshold = threshold

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Rescale the gradients in-place with ``clip_grad_norm_``.

        Args:
            params: parameters whose gradients are clipped.
        """
        clip_by_norm = torch.nn.utils.clip_grad_norm_
        clip_by_norm(params, max_norm=self.threshold)
class GradValueClipper(ClipOperation):
    """Clip every gradient entry to a maximum absolute value.

    Attributes:
        threshold: upper bound for the absolute value of gradient entries.
    """

    threshold: float

    def __init__(self, threshold: float = 1) -> None:
        """Store the validated threshold.

        Args:
            threshold: upper bound for the absolute value of gradient
                entries.

        Raises:
            ValueError: if the threshold is not positive.
        """
        super().__init__()
        _validate_threshold(threshold)
        self.threshold: Final = threshold

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Clamp the gradients in-place with ``clip_grad_value_``.

        Args:
            params: parameters whose gradients are clipped.
        """
        clip_by_value = torch.nn.utils.clip_grad_value_
        clip_by_value(params, clip_value=self.threshold)
def reciprocal_clipping(zt: float, z_thresh: float) -> float:
    """Reciprocal clipping as recommended in https://arxiv.org/pdf/2504.02507.

    Rather than clipping exactly to the threshold, the returned statistic
    shrinks as the spike grows: the larger ``zt``, the smaller the result.

    Args:
        zt: the Z-statistic or ratio of the current gradient norm.
        z_thresh: the threshold for the z-statistic values.

    Returns:
        The new statistic ``z_thresh**2 / zt``; it is lower than ``z_thresh``
        whenever ``zt`` is greater than ``z_thresh``.
    """
    squared_threshold = z_thresh**2
    return squared_threshold / zt
def mean_clipping(zt: float, z_thresh: float) -> float:
    """Discard the deviation entirely by always returning zero.

    Both arguments are accepted only to match the clipping function
    signature and are ignored.

    Args:
        zt: the Z-statistic or ratio of the current gradient norm.
        z_thresh: the threshold for the z-statistic values.

    Returns:
        Always 0.0.
    """
    del zt, z_thresh  # unused by design
    return 0.0
def max_clipping(zt: float, z_thresh: float) -> float:
    """Clip exactly to the threshold, whatever the size of the spike.

    Args:
        zt: the Z-statistic or ratio of the current gradient norm (ignored).
        z_thresh: the threshold for the z-statistic values.

    Returns:
        The threshold itself, unchanged.
    """
    del zt  # the result does not depend on the observed statistic
    return z_thresh
class ClippingCriterion(abc.ABC):
    """Criterion that detects when to clip and determines the clip value.

    Subclasses must implement ``should_clip`` and ``get_clip_value``. The
    remaining methods manage optional internal statistics and do nothing by
    default.
    """

    @abc.abstractmethod
    def should_clip(self, value: float) -> bool:
        """Tell whether gradients must be clipped for the current value.

        Args:
            value: current gradient norm or value to evaluate.

        Returns:
            True if gradients should be clipped, False otherwise.
        """

    @abc.abstractmethod
    def get_clip_value(self, value: float) -> float:
        """Compute the value that gradients should be clipped to.

        Args:
            value: current gradient norm or value.

        Returns:
            The value to clip gradients to.
        """

    def update(self, value: float) -> None:
        """Incorporate a newly observed value (no-op by default).

        Args:
            value: new gradient norm or value to incorporate.
        """
        del value

    def set_statistics(self, mean: float, variance: float = 0.0) -> None:
        """Seed the statistics from warmup data (no-op by default).

        Args:
            mean: mean value from the warmup period.
            variance: variance from the warmup period (if applicable).
        """
        del mean, variance

    def reset(self) -> None:
        """Restore the initial internal state (no-op by default)."""
class EMACriterion(ClippingCriterion):
    """Criterion based on the exponential moving average of the values.

    Only a running mean is tracked. Clipping is requested when the ratio
    between the observed value and that mean is above ``r_thresh``. While
    the running mean is exactly zero, no clipping is requested.

    Attributes:
        alpha: exponential moving average decay factor.
        r_thresh: ratio threshold between current_norm and mean_norm.
        clipping_function: function to determine clipping behavior.
    """

    alpha: float
    r_thresh: float
    clipping_function: ClipFunction
    _mu_t: float

    def __init__(
        self,
        alpha: float = 0.98,
        r_thresh: float = 1.05,
        clipping_function: ClipFunction = max_clipping,
    ):
        """Initialize.

        Args:
            alpha: exponential moving average decay factor.
            r_thresh: ratio threshold between current_norm and mean_norm.
            clipping_function: function to determine clipping behavior.

        Raises:
            ValueError: if ``r_thresh`` is not positive.
        """
        self.alpha: Final = alpha
        self.r_thresh: Final = r_thresh
        self.clipping_function: Final = clipping_function
        self._mu_t = 0.0
        _validate_threshold(r_thresh)

    @override
    def should_clip(self, value: float) -> bool:
        running_mean = self._mu_t
        # a zero mean means that no statistics are available yet
        return running_mean != 0.0 and value / running_mean > self.r_thresh

    @override
    def get_clip_value(self, value: float) -> float:
        if self._mu_t != 0.0:
            ratio = value / self._mu_t
            clipping_factor = self.clipping_function(ratio, self.r_thresh)
            return self._mu_t * clipping_factor

        # without statistics the value is returned as it is
        return value

    @override
    def update(self, value: float) -> None:
        decayed_mean = self.alpha * self._mu_t
        contribution = (1 - self.alpha) * value
        self._mu_t = decayed_mean + contribution

    @override
    def set_statistics(self, mean: float, variance: float = 0.0) -> None:
        # this criterion does not track the variance, which is ignored
        self._mu_t = mean

    @override
    def reset(self) -> None:
        self._mu_t = 0.0
class ZStatCriterion(ClippingCriterion):
    """Clipping criterion based on the Z-statistic.

    Mean and variance are both tracked with exponential moving averages, and
    the threshold applies to the Z-score (standardized deviation) of the
    observed value. While the running mean is exactly zero, no clipping is
    requested.

    See also https://arxiv.org/pdf/2504.02507.

    Attributes:
        alpha: exponential moving average decay factor (0 < alpha < 1).
        z_thresh: threshold compared against |z_score|.
        clipping_function: function to determine clipping behavior.
    """

    alpha: float
    z_thresh: float
    clipping_function: ClipFunction
    _eps: float = 1e-8
    _mu_t: float
    _v_t: float

    def __init__(
        self,
        alpha: float = 0.97,
        z_thresh: float = 2.5,
        clipping_function: ClipFunction = reciprocal_clipping,
    ):
        """Initialize.

        Args:
            alpha: exponential moving average decay factor (0 < alpha < 1).
            z_thresh: threshold for the Z-score.
            clipping_function: function to determine clipping behavior.

        Raises:
            ValueError: if ``z_thresh`` is not positive.
        """
        self.alpha: Final = alpha
        self.z_thresh: Final = z_thresh
        self.clipping_function: Final = clipping_function
        self._mu_t = 0.0
        self._v_t = 1.0
        _validate_threshold(z_thresh)

    def _z_score(self, value: float) -> float:
        """Standardize a value with the current running statistics."""
        deviation = value - self._mu_t
        return deviation / (math.sqrt(self._v_t) + self._eps)

    @override
    def should_clip(self, value: float) -> bool:
        """Check if the Z-score exceeds the threshold."""
        # a zero mean means that no statistics are available yet
        return (
            self._mu_t != 0.0 and abs(self._z_score(value)) > self.z_thresh
        )

    @override
    def get_clip_value(self, value: float) -> float:
        if self._mu_t == 0.0:
            return value

        magnitude = abs(self._z_score(value))
        if magnitude <= self.z_thresh:
            return value

        new_z_score = self.clipping_function(magnitude, self.z_thresh)
        return self._mu_t + new_z_score * math.sqrt(self._v_t)

    @override
    def update(self, value: float) -> None:
        # the squared deviation is measured against the mean before update
        squared_deviation = (value - self._mu_t) ** 2
        self._v_t, self._mu_t = (
            self.alpha * self._v_t + (1 - self.alpha) * squared_deviation,
            self.alpha * self._mu_t + (1 - self.alpha) * value,
        )

    @override
    def set_statistics(self, mean: float, variance: float = 0.0) -> None:
        self._mu_t = mean
        # a non-positive variance leaves the current one in place
        if variance > 0:
            self._v_t = variance

    @override
    def reset(self) -> None:
        self._mu_t, self._v_t = 0.0, 1.0
class StatsCollector:
    """Bounded collection of samples exposing their mean and variance.

    Attributes:
        max_samples: the number of collected samples for completion.
        active: whether the collector is currently in use.
    """

    max_samples: int
    _data: list[float]
    active: bool

    def __init__(self, max_samples: int):
        """Create an empty, active collector.

        Args:
            max_samples: the number of collected samples for completion.
        """
        self.max_samples: Final = max_samples
        self._data = []
        self.active = True

    def __len__(self) -> int:
        """Return how many samples have been stored so far."""
        return len(self._data)

    @property
    def mean(self) -> float:
        """Mean of the stored samples, or 0.0 when there are none."""
        count = len(self._data)
        return sum(self._data) / count if count else 0.0

    @property
    def variance(self) -> float:
        """Sample variance, or 1.0 when fewer than two samples are stored."""
        count = len(self._data)
        if count <= 1:
            return 1.0

        center = self.mean
        squared_deviations = ((x - center) ** 2 for x in self._data)
        return sum(squared_deviations) / (count - 1)

    def is_complete(self) -> bool:
        """Check for completion, deactivating the collector when reached."""
        if len(self) < self.max_samples:
            return False

        self.active = False
        return True

    def append(self, value: float) -> None:
        """Store a new sample unless the collection is already full."""
        if len(self) >= self.max_samples:
            return

        self._data.append(value)

    def reset(self) -> None:
        """Drop every stored sample and reactivate the collector."""
        self._data.clear()
        self.active = True
class HistClipper(ClipOperation):
    """Global gradient clipping strategy that uses previous gradient statistics.

    The gradients' norm is renormalized according to a clipping criterion.
    During the first ``n_warmup_steps`` calls, the global norms are only
    collected and the warmup strategy is applied instead; the collected mean
    and variance then initialize the criterion.

    Attributes:
        criterion: the clipping criterion to determine when and how to clip.
        warmup_clip_strategy: the clipping strategy used during warmup.
        n_warmup_steps: the number of warmup steps to collect initial stats.
    """

    _default_criterion: ClassVar[ZStatCriterion] = ZStatCriterion()
    _default_grad_op: ClassVar[GradNormClipper] = GradNormClipper()
    criterion: ClippingCriterion
    warmup_clip_strategy: GradientOpProtocol
    n_warmup_steps: int
    _warmup_handler: StatsCollector

    def __init__(
        self,
        criterion: ClippingCriterion = _default_criterion,
        warmup_clip_strategy: p.GradientOpProtocol = _default_grad_op,
        n_warmup_steps: int = 20,
    ) -> None:
        """Initialize.

        Args:
            criterion: the clipping criterion to determine when and how to
                clip. When omitted, the clipper gets its own copy of the
                default criterion.
            warmup_clip_strategy: the clipping strategy used during warmup.
            n_warmup_steps: the number of warmup steps to collect initial
                stats.
        """
        super().__init__()
        # The default criterion is a single class-level object, and criteria
        # are stateful (`update` and `reset` mutate them). Using it directly
        # would make every clipper created with the defaults share the same
        # running statistics, so each instance gets its own copy.
        if criterion is HistClipper._default_criterion:
            criterion = copy.copy(criterion)

        self.criterion: Final = criterion
        self.warmup_clip_strategy: Final = warmup_clip_strategy
        self.n_warmup_steps: Final = n_warmup_steps
        self._warmup_handler: Final = StatsCollector(n_warmup_steps)

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Apply global gradient clipping.

        Args:
            params: model parameters to clip.

        Side Effects:
            Modifies gradients in-place if clipping is applied, and updates
            the warmup samples or the criterion statistics with the global
            norm when that norm is finite.
        """
        # needed to allow multiple iterations
        params_list = list(params)
        squared_norms = [
            (param.grad**2).sum().item()
            for param in params_list
            if param.grad is not None
        ]
        global_norm = math.sqrt(sum(squared_norms))
        # An inf or NaN norm would poison the warmup mean and the moving
        # averages of the criterion for all the following steps: such norms
        # are kept out of the statistics.
        norm_is_finite = math.isfinite(global_norm)

        if self._warmup_handler.active:
            # `is_complete` also deactivates the handler once it is full
            if not self._warmup_handler.is_complete():
                if norm_is_finite:
                    self._warmup_handler.append(global_norm)
                self.warmup_clip_strategy(params_list)
                return

            # first call after the warmup: seed the criterion statistics
            self.criterion.set_statistics(
                self._warmup_handler.mean, self._warmup_handler.variance
            )

        if self.criterion.should_clip(global_norm):
            clip_value = self.criterion.get_clip_value(global_norm)
            torch.nn.utils.clip_grad_norm_(params_list, clip_value)

        # statistics are updated with the norm measured before clipping
        if norm_is_finite:
            self.criterion.update(global_norm)

    def reset(self) -> None:
        """Reset the state.

        Clears the warmup samples, restarting the warmup, and resets the
        criterion statistics.
        """
        self._warmup_handler.reset()
        self.criterion.reset()
class ParamHistClipper(ClipOperation):
    """Gradient clipping strategy that keeps per-parameter statistics.

    The norm of each parameter's gradient is renormalized according to a
    clipping criterion. Every parameter has its own copy of the criterion
    and its own warmup collector, both created on first use.

    Attributes:
        criterion: the clipping criterion to determine when and how to clip.
        warmup_clip_strategy: the clipping strategy used during warmup.
        n_warmup_steps: the number of warmup steps to collect initial stats.
    """

    _default_criterion: ClassVar[ZStatCriterion] = ZStatCriterion()
    _default_grad_op: ClassVar[GradNormClipper] = GradNormClipper()
    criterion: ClippingCriterion
    n_warmup_steps: int
    warmup_clip_strategy: GradientOpProtocol
    _dict_criterion: defaultdict[int, ClippingCriterion]
    _dict_warmup_handler: defaultdict[int, StatsCollector]

    def __init__(
        self,
        criterion: ClippingCriterion = _default_criterion,
        warmup_clip_strategy: p.GradientOpProtocol = _default_grad_op,
        n_warmup_steps: int = 20,
    ) -> None:
        """Initialize.

        Args:
            criterion: the clipping criterion to determine when and how to
                clip. It is used as a template: each parameter receives a
                shallow copy of it.
            warmup_clip_strategy: the clipping strategy used during warmup.
            n_warmup_steps: the number of warmup steps to collect initial
                stats.
        """
        super().__init__()
        self.criterion: Final = criterion
        self.n_warmup_steps: Final = n_warmup_steps
        self.warmup_clip_strategy: Final = warmup_clip_strategy
        # both mappings are keyed by `id(param)` and filled lazily
        self._dict_criterion = defaultdict(lambda: copy.copy(criterion))
        self._dict_warmup_handler = defaultdict(
            lambda: StatsCollector(n_warmup_steps)
        )

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Apply gradient clipping to each parameter separately.

        Args:
            params: model parameters to clip. Parameters without a gradient
                are skipped.

        Side Effects:
            Modifies gradients in-place if clipping is applied, and updates
            the warmup samples or the criterion statistics of each parameter
            with its gradient norm when that norm is finite.
        """
        for param in params:
            grad = param.grad
            if grad is None:
                continue

            grad_norm: float = grad.norm(2, dtype=float).item()
            # An inf or NaN norm would poison the warmup mean and the moving
            # averages of the criterion for all the following steps: such
            # norms are kept out of the statistics.
            norm_is_finite = math.isfinite(grad_norm)
            param_id = id(param)
            warmup_handler = self._dict_warmup_handler[param_id]
            criterion = self._dict_criterion[param_id]
            if warmup_handler.active:
                # `is_complete` also deactivates the handler once it is full
                if not warmup_handler.is_complete():
                    if norm_is_finite:
                        warmup_handler.append(grad_norm)
                    self.warmup_clip_strategy([param])
                    continue

                # first call after the warmup: seed the statistics
                criterion.set_statistics(
                    warmup_handler.mean, warmup_handler.variance
                )

            if criterion.should_clip(grad_norm):
                clip_value = criterion.get_clip_value(grad_norm)
                torch.nn.utils.clip_grad_norm_([param], clip_value)

            # statistics are updated with the norm measured before clipping
            if norm_is_finite:
                criterion.update(grad_norm)

    def reset(self) -> None:
        """Reset the state.

        Discards every per-parameter criterion and warmup collector, so
        that they are created anew on the next call.
        """
        self._dict_criterion.clear()
        self._dict_warmup_handler.clear()