"""Module containing gradient operations."""
import abc
import copy
import math
from collections import defaultdict
from collections.abc import Callable, Iterable
from typing import ClassVar, Final, TypeAlias
import torch
from typing_extensions import override
from drytorch.core import protocols as p
from drytorch.core.protocols import GradientOpProtocol
# Names exported by ``from <module> import *``.
__all__ = [
'ClippingCriterion',
'EMACriterion',
'GradNormClipper',
'GradParamNormalizer',
'GradValueClipper',
'GradZScoreNormalizer',
'HistClipper',
'NoOp',
'ParamHistClipper',
'StatsCollector',
'ZStatCriterion',
'max_clipping',
'mean_clipping',
'reciprocal_clipping',
]
# Signature shared by the clipping functions: they are called with the current
# statistic (a Z-score or a ratio) and its threshold, and return a float.
ClipFunction: TypeAlias = Callable[[float, float], float]
def _validate_threshold(threshold: float) -> None:
    """Ensure that a clipping threshold is strictly positive.

    Args:
        threshold: the candidate threshold value.

    Raises:
        ValueError: if the threshold is zero, negative, or NaN.
    """
    # ``not threshold > 0`` (instead of ``threshold <= 0``) also rejects NaN,
    # for which every ordered comparison evaluates to False.
    if not threshold > 0:
        raise ValueError('Gradient threshold must be positive.')
[docs]
class NoOp(GradientOpProtocol):
    """Gradient operation that leaves every gradient untouched."""

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Do nothing; the given parameters are not inspected."""
[docs]
class GradParamNormalizer(p.GradientOpProtocol):
    """Rescale the gradient of each parameter to unit L2 norm."""

    # Added to the norm so that an all-zero gradient does not divide by zero.
    _eps = 1e-8

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Divide every available gradient by its own L2 norm, in-place.

        Args:
            params: parameters whose gradients are normalized; those without
                a gradient are skipped.
        """
        for parameter in params:
            if (gradient := parameter.grad) is not None:
                gradient.div_(gradient.norm(2) + self._eps)
[docs]
class GradZScoreNormalizer(p.GradientOpProtocol):
    """Gradient normalizing strategy using Z-score normalization.

    Each gradient is shifted by its own mean and divided by its own sample
    standard deviation (Bessel-corrected), independently of the others.
    """

    # Added to the standard deviation to avoid dividing by zero when all the
    # entries of a gradient are equal.
    _eps = 1e-8

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Normalize gradients using Z-score in-place.

        Args:
            params: parameters whose gradients are normalized. Parameters
                without a gradient, or whose gradient has fewer than two
                elements, are left untouched.
        """
        for param in params:
            grad = param.grad
            # The unbiased standard deviation needs at least two elements:
            # for a single element it is NaN, which would overwrite the
            # gradient with NaN. Such gradients are therefore skipped.
            if grad is None or grad.numel() < 2:
                continue
            mean = grad.mean()
            std = grad.std(unbiased=True)
            grad.sub_(mean).div_(std + self._eps)
class ClipOperation(p.GradientOpProtocol, abc.ABC):
    """Abstract parent of the gradient clipping strategies."""

    @abc.abstractmethod
    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Apply the operation to the gradients of the given parameters.

        Args:
            params: parameters whose gradients are processed.
        """
[docs]
class GradNormClipper(ClipOperation):
    """Clip gradients so that their norm does not exceed a threshold.

    Attributes:
        threshold: Maximum norm value of the clipped gradients.
    """

    threshold: float

    def __init__(self, threshold: float = 1) -> None:
        """Initialize.

        Args:
            threshold: Maximum norm value of the clipped gradients.

        Raises:
            ValueError: if the threshold fails validation.
        """
        super().__init__()
        _validate_threshold(threshold)
        self.threshold = threshold

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Clip the gradients by norm, in-place.

        Args:
            params: parameters whose gradients are clipped.
        """
        max_norm = self.threshold
        torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm)
[docs]
class GradValueClipper(ClipOperation):
    """Clip every gradient entry to a maximum absolute value.

    Attributes:
        threshold: Maximum absolute value of the clipped gradients.
    """

    threshold: float

    def __init__(self, threshold: float = 1) -> None:
        """Initialize.

        Args:
            threshold: Maximum absolute value of the clipped gradients.

        Raises:
            ValueError: if the threshold fails validation.
        """
        super().__init__()
        _validate_threshold(threshold)
        self.threshold: Final = threshold

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Clip the gradients by value, in-place.

        Args:
            params: parameters whose gradients are clipped.
        """
        clip_value = self.threshold
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_value)
[docs]
def reciprocal_clipping(zt: float, z_thresh: float) -> float:
    """Reciprocal clipping as recommended in https://arxiv.org/pdf/2504.02507.

    Rather than clipping to the threshold itself, the returned value shrinks
    as the spike grows: the larger ``zt``, the smaller the result.

    Args:
        zt: the Z-statistic or ratio of the current gradient norm.
        z_thresh: the threshold for the z-statistic values.

    Returns:
        The replacement statistic ``z_thresh**2 / zt``. For positive inputs
        it is below ``z_thresh`` whenever ``zt`` exceeds ``z_thresh``.

    Raises:
        ZeroDivisionError: if ``zt`` is zero.
    """
    squared_threshold = z_thresh**2
    return squared_threshold / zt
[docs]
def mean_clipping(zt: float, z_thresh: float) -> float:
    """Clipping function that always returns zero.

    Used as a replacement Z-score, zero corresponds to the running mean.

    Args:
        zt: the Z-statistic or ratio of the current gradient norm (ignored).
        z_thresh: the threshold for the z-statistic values (ignored).

    Returns:
        Always 0.0.
    """
    del zt, z_thresh  # accepted only to match the ClipFunction signature
    return 0.0
[docs]
def max_clipping(zt: float, z_thresh: float) -> float:
    """Standard clipping: the result is the threshold itself.

    Args:
        zt: the Z-statistic or ratio of the current gradient norm (ignored).
        z_thresh: the threshold for the z-statistic values.

    Returns:
        ``z_thresh``, unchanged.
    """
    del zt  # accepted only to match the ClipFunction signature
    return z_thresh
[docs]
class ClippingCriterion(abc.ABC):
    """Criterion that detects when to clip and determines the clipping value.

    Subclasses must implement ``should_clip`` and ``get_clip_value``. The
    remaining hooks (``update``, ``set_statistics`` and ``reset``) do nothing
    by default and may be overridden by criteria that keep statistics.
    """

    @abc.abstractmethod
    def should_clip(self, value: float) -> bool:
        """Decide whether the gradients should be clipped.

        Args:
            value: current gradient norm or value to evaluate.

        Returns:
            True if gradients should be clipped, False otherwise.
        """

    @abc.abstractmethod
    def get_clip_value(self, value: float) -> float:
        """Compute the clipping threshold from the current statistics.

        Args:
            value: current gradient norm or value.

        Returns:
            The value to clip gradients to.
        """

    def update(self, value: float) -> None:
        """Incorporate a newly observed value (no-op by default).

        Args:
            value: new gradient norm or value to incorporate.
        """
        del value

    def set_statistics(self, mean: float, variance: float = 0.0) -> None:
        """Initialize the statistics from warmup data (no-op by default).

        Args:
            mean: mean value from the warmup period.
            variance: variance from the warmup period (if applicable).
        """
        del mean, variance

    def reset(self) -> None:
        """Restore the initial statistics (no-op by default)."""
[docs]
class EMACriterion(ClippingCriterion):
    """Clipping criterion based on an Exponential Moving Average.

    Only the running mean of the observed values is tracked. Clipping is
    requested when the current value exceeds the running mean by more than a
    factor of ``r_thresh``.

    Attributes:
        alpha: exponential moving average decay factor.
        r_thresh: ratio threshold between current_norm and mean_norm.
        clipping_function: function to determine clipping behavior.
    """

    alpha: float
    r_thresh: float
    clipping_function: ClipFunction
    _mu_t: float

    def __init__(
        self,
        alpha: float = 0.98,
        r_thresh: float = 1.05,
        clipping_function: ClipFunction = max_clipping,
    ):
        """Initialize.

        Args:
            alpha: exponential moving average decay factor.
            r_thresh: ratio threshold between current_norm and mean_norm.
            clipping_function: function to determine clipping behavior.

        Raises:
            ValueError: if ``r_thresh`` fails validation.
        """
        self.alpha: Final = alpha
        self.r_thresh: Final = r_thresh
        self.clipping_function: Final = clipping_function
        self._mu_t = 0.0
        _validate_threshold(r_thresh)

    @override
    def should_clip(self, value: float) -> bool:
        """Check whether the ratio to the running mean exceeds the threshold.

        Never requests clipping while the running mean is exactly zero.
        """
        running_mean = self._mu_t
        if running_mean == 0.0:
            return False
        ratio = value / running_mean
        return ratio > self.r_thresh

    @override
    def get_clip_value(self, value: float) -> float:
        """Return the running mean scaled by the clipping function's output.

        The value is returned unchanged while the running mean is zero.
        """
        running_mean = self._mu_t
        if running_mean == 0.0:
            return value
        factor = self.clipping_function(value / running_mean, self.r_thresh)
        return running_mean * factor

    @override
    def update(self, value: float) -> None:
        """Blend the new value into the running mean."""
        decay = self.alpha
        self._mu_t = decay * self._mu_t + (1 - decay) * value

    @override
    def set_statistics(self, mean: float, variance: float = 0.0) -> None:
        """Overwrite the running mean; the variance is not used."""
        self._mu_t = mean

    @override
    def reset(self) -> None:
        """Set the running mean back to zero."""
        self._mu_t = 0.0
[docs]
class ZStatCriterion(ClippingCriterion):
    """Clipping criterion based on the Z-statistic.

    Both mean and variance are tracked with exponential moving averages, and
    the threshold applies to the Z-score (standardized deviation) of the
    current value. See also https://arxiv.org/pdf/2504.02507.

    Attributes:
        alpha: exponential moving average decay factor (0 < alpha < 1).
        z_thresh: Z-score threshold compared against |z_score|.
        clipping_function: function to determine clipping behavior.
    """

    alpha: float
    z_thresh: float
    clipping_function: ClipFunction
    _mu_t: float
    _v_t: float
    # Added to the standard deviation when computing the Z-score.
    _eps: float = 1e-8

    def __init__(
        self,
        alpha: float = 0.97,
        z_thresh: float = 2.5,
        clipping_function: ClipFunction = reciprocal_clipping,
    ):
        """Initialize.

        Args:
            alpha: exponential moving average decay factor (0 < alpha < 1).
            z_thresh: threshold for the Z-score.
            clipping_function: function to determine clipping behavior.

        Raises:
            ValueError: if ``z_thresh`` fails validation.
        """
        self.alpha: Final = alpha
        self.z_thresh: Final = z_thresh
        self.clipping_function: Final = clipping_function
        self._mu_t = 0.0
        self._v_t = 1.0
        _validate_threshold(z_thresh)

    def _z_score(self, value: float) -> float:
        """Standardize a value with the running mean and variance."""
        spread = math.sqrt(self._v_t) + self._eps
        return (value - self._mu_t) / spread

    @override
    def should_clip(self, value: float) -> bool:
        """Check whether the absolute Z-score exceeds the threshold.

        Never requests clipping while the running mean is exactly zero.
        """
        if self._mu_t == 0.0:
            return False
        return abs(self._z_score(value)) > self.z_thresh

    @override
    def get_clip_value(self, value: float) -> float:
        """Map the value to the one matching the clipped Z-score.

        The value is returned unchanged while the running mean is zero or
        when its absolute Z-score is within the threshold.
        """
        if self._mu_t == 0.0:
            return value
        magnitude = abs(self._z_score(value))
        if magnitude <= self.z_thresh:
            return value
        clipped_z = self.clipping_function(magnitude, self.z_thresh)
        return self._mu_t + clipped_z * math.sqrt(self._v_t)

    @override
    def update(self, value: float) -> None:
        """Blend the new value into the running variance and mean."""
        decay = self.alpha
        # The squared deviation is measured against the mean *before* it is
        # updated, so the variance must be refreshed first.
        squared_deviation = (value - self._mu_t) ** 2
        self._v_t = decay * self._v_t + (1 - decay) * squared_deviation
        self._mu_t = decay * self._mu_t + (1 - decay) * value

    @override
    def set_statistics(self, mean: float, variance: float = 0.0) -> None:
        """Overwrite the running mean and, if positive, the variance."""
        self._mu_t = mean
        if variance > 0:
            self._v_t = variance

    @override
    def reset(self) -> None:
        """Restore the initial statistics (zero mean, unit variance)."""
        self._mu_t = 0.0
        self._v_t = 1.0
[docs]
class StatsCollector:
    """Bounded collection of samples exposing their mean and variance.

    Attributes:
        max_samples: the number of collected samples for completion.
        active: whether the collector is currently in use.
    """

    max_samples: int
    _data: list[float]
    active: bool

    def __init__(self, max_samples: int):
        """Initialize an empty, active collector.

        Args:
            max_samples: the number of collected samples for completion.
        """
        self.max_samples: Final = max_samples
        self._data = []
        self.active = True

    def __len__(self) -> int:
        """Return the number of samples collected so far."""
        return len(self._data)

    @property
    def mean(self) -> float:
        """Mean of the collected samples, or 0.0 when there are none."""
        count = len(self._data)
        return sum(self._data) / count if count else 0.0

    @property
    def variance(self) -> float:
        """Sample variance (n - 1 denominator) of the collected samples.

        Falls back to 1.0 when fewer than two samples are available.
        """
        count = len(self._data)
        if count < 2:
            return 1.0
        center = self.mean
        squared_deviations = [(sample - center) ** 2 for sample in self._data]
        return sum(squared_deviations) / (count - 1)

    def is_complete(self) -> bool:
        """Check if the collection is complete.

        As a side effect, a complete collector is marked as inactive.
        """
        if len(self) >= self.max_samples:
            self.active = False
            return True
        return False

    def append(self, value: float) -> None:
        """Store a new sample unless the collection is already full."""
        has_room = len(self) < self.max_samples
        if has_room:
            self._data.append(value)

    def reset(self) -> None:
        """Discard all samples and mark the collector as active again."""
        self._data.clear()
        self.active = True
[docs]
class HistClipper(ClipOperation):
    """Global gradient clipping strategy using past gradient statistics.

    The global gradient norm is renormalized according to a clipping
    criterion. During the first ``n_warmup_steps`` calls the warmup strategy
    is applied instead, while norms are collected to initialize the
    criterion's statistics.

    Attributes:
        criterion: the clipping criterion to determine when and how to clip.
        warmup_clip_strategy: the clipping strategy used during warmup.
        n_warmup_steps: the number of warmup steps to collect initial stats.
    """

    _default_criterion: ClassVar[ZStatCriterion] = ZStatCriterion()
    _default_grad_op: ClassVar[GradNormClipper] = GradNormClipper()
    criterion: ClippingCriterion
    warmup_clip_strategy: GradientOpProtocol
    n_warmup_steps: int
    _warmup_handler: StatsCollector

    def __init__(
        self,
        criterion: ClippingCriterion = _default_criterion,
        warmup_clip_strategy: p.GradientOpProtocol = _default_grad_op,
        n_warmup_steps: int = 20,
    ) -> None:
        """Initialize.

        Args:
            criterion: the clipping criterion to determine when and how to
                clip. When omitted, a private copy of the default criterion
                is used.
            warmup_clip_strategy: the clipping strategy used during warmup.
            n_warmup_steps: the number of warmup steps to collect initial
                stats.
        """
        super().__init__()
        # The default criterion is a single object created at class
        # definition time and it is stateful. Copying it keeps the running
        # statistics of different clippers independent of each other.
        if criterion is HistClipper._default_criterion:
            criterion = copy.copy(criterion)
        self.criterion: Final = criterion
        self.warmup_clip_strategy: Final = warmup_clip_strategy
        self.n_warmup_steps: Final = n_warmup_steps
        self._warmup_handler: Final = StatsCollector(n_warmup_steps)

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Apply global gradient clipping.

        Args:
            params: model parameters to clip.

        Side Effects:
            Modifies gradients in-place if clipping is applied, and updates
            the warmup collector or the criterion's statistics.
        """
        # needed to allow multiple iterations
        params_list = list(params)
        squared_norms = [
            (param.grad**2).sum().item()
            for param in params_list
            if param.grad is not None
        ]
        global_norm = math.sqrt(sum(squared_norms))
        if self._warmup_handler.active:
            if not self._warmup_handler.is_complete():
                # Still warming up: record the norm and use the fallback.
                self._warmup_handler.append(global_norm)
                self.warmup_clip_strategy(params_list)
                return
            # First call after warmup (is_complete() has just deactivated
            # the collector): seed the criterion with the collected stats.
            self.criterion.set_statistics(
                self._warmup_handler.mean, self._warmup_handler.variance
            )
        if self.criterion.should_clip(global_norm):
            clip_value = self.criterion.get_clip_value(global_norm)
            torch.nn.utils.clip_grad_norm_(params_list, clip_value)
        # The statistics are updated with the norm measured before clipping.
        self.criterion.update(global_norm)

    def reset(self) -> None:
        """Reset the warmup collector and the criterion's statistics."""
        self._warmup_handler.reset()
        self.criterion.reset()
[docs]
class ParamHistClipper(ClipOperation):
    """Gradient clipping strategy that keeps per-parameter statistics.

    The norm of each parameter's gradient is renormalized according to that
    parameter's own copy of the clipping criterion.

    Attributes:
        criterion: the clipping criterion to determine when and how to clip.
        warmup_clip_strategy: the clipping strategy used during warmup.
        n_warmup_steps: the number of warmup steps to collect initial stats.
    """

    _default_criterion: ClassVar[ZStatCriterion] = ZStatCriterion()
    _default_grad_op: ClassVar[GradNormClipper] = GradNormClipper()
    criterion: ClippingCriterion
    n_warmup_steps: int
    warmup_clip_strategy: GradientOpProtocol
    _dict_criterion: defaultdict[int, ClippingCriterion]
    _dict_warmup_handler: defaultdict[int, StatsCollector]

    def __init__(
        self,
        criterion: ClippingCriterion = _default_criterion,
        warmup_clip_strategy: p.GradientOpProtocol = _default_grad_op,
        n_warmup_steps: int = 20,
    ) -> None:
        """Initialize.

        Args:
            criterion: the clipping criterion to determine when and how to
                clip; it serves as a template copied for each parameter.
            warmup_clip_strategy: the clipping strategy used during warmup.
            n_warmup_steps: the number of warmup steps to collect initial
                stats.
        """
        super().__init__()
        self.criterion: Final = criterion
        self.n_warmup_steps: Final = n_warmup_steps
        self.warmup_clip_strategy: Final = warmup_clip_strategy
        # Both mappings are keyed by id(param) and filled lazily the first
        # time a parameter is seen.
        self._dict_criterion = defaultdict(lambda: copy.copy(criterion))
        self._dict_warmup_handler = defaultdict(
            lambda: StatsCollector(n_warmup_steps)
        )

    def __call__(self, params: Iterable[torch.nn.Parameter]) -> None:
        """Apply gradient clipping to each parameter independently.

        Args:
            params: Model parameters to clip; those without a gradient are
                skipped.

        Side Effects:
            Modifies gradients in-place if clipping is applied, and updates
            the per-parameter warmup collectors and criteria.
        """
        for param in params:
            if param.grad is None:
                continue
            norm_value: float = param.grad.norm(2, dtype=float).item()
            key = id(param)
            collector = self._dict_warmup_handler[key]
            own_criterion = self._dict_criterion[key]
            if collector.active:
                if not collector.is_complete():
                    # Still warming up: record the norm, use the fallback.
                    collector.append(norm_value)
                    self.warmup_clip_strategy([param])
                    continue
                # Warmup has just ended: seed this parameter's criterion.
                own_criterion.set_statistics(
                    collector.mean, collector.variance
                )
            if own_criterion.should_clip(norm_value):
                limit = own_criterion.get_clip_value(norm_value)
                torch.nn.utils.clip_grad_norm_([param], limit)
            own_criterion.update(norm_value)

    def reset(self):
        """Forget every per-parameter criterion and warmup collector."""
        self._dict_criterion.clear()
        self._dict_warmup_handler.clear()