drytorch.lib.gradient_ops

Module containing gradient operations.

Functions

max_clipping(zt, z_thresh)

Standard clipping to the threshold value.

mean_clipping(zt, z_thresh)

Clip to the mean value (effectively setting the gradient norm to its running mean).

reciprocal_clipping(zt, z_thresh)

Reciprocal clipping as recommended in https://arxiv.org/pdf/2504.02507.

Classes

ClipOperation(*args, **kwargs)

Abstract base class for gradient operations.

ClippingCriterion()

Criterion that detects when to clip and determines the clipping value.

EMACriterion(alpha, r_thresh, ...)

Clipping criterion based on Exponential Moving Average.

GradNormClipper([threshold])

Gradient norm clipping strategy.

GradParamNormalizer(*args, **kwargs)

Strategy that normalizes each parameter's gradient to unit norm.

GradValueClipper([threshold])

Gradient value clipping strategy.

GradZScoreNormalizer(*args, **kwargs)

Gradient normalizing strategy using Z-score normalization.

HistClipper(criterion, warmup_clip_strategy, ...)

Global gradient clipping strategy that uses previous gradient statistics.

NoOp()

Placeholder performing no gradient action.

ParamHistClipper(criterion, ...)

Gradient clipping strategy that keeps per-parameter statistics.

StatsCollector(max_samples)

Collect samples during warmup to compute their mean and variance.

ZStatCriterion(alpha, z_thresh, ...)

Clipping criterion based on the Z-statistic.

class ClippingCriterion

Bases: ABC

Criterion that detects when to clip and determines the clipping value.

abstractmethod should_clip(value: float) → bool

Determine whether to clip gradients based on the current value.

Parameters:

value (float) – current gradient norm or value to evaluate.

Returns:

True if gradients should be clipped, False otherwise.

Return type:

bool

abstractmethod get_clip_value(value: float) → float

Calculate the clipping threshold based on current statistics.

Parameters:

value (float) – Current gradient norm or value.

Returns:

The value to clip gradients to.

Return type:

float

update(value: float) → None

Update internal statistics with a new observed value.

Parameters:

value (float) – new gradient norm or value to incorporate.

Return type:

None

set_statistics(mean: float, variance: float = 0.0) → None

Initialize statistics from warmup data.

Parameters:
  • mean (float) – mean value from the warmup period.

  • variance (float) – variance from the warmup period (if applicable).

Return type:

None

reset() → None

Reset all internal statistics to initial state.

Return type:

None

class EMACriterion(alpha: float = 0.98, r_thresh: float = 1.05, clipping_function: Callable[[float, float], float] = max_clipping)

Bases: ClippingCriterion

Clipping criterion based on Exponential Moving Average.

It uses only the running mean of gradient norms to detect outliers, clipping when the current norm exceeds r_thresh times the running mean.

alpha

exponential moving average decay factor.

Type:

float

r_thresh

threshold on the ratio current_norm / mean_norm.

Type:

float

clipping_function

function to determine clipping behavior.

Type:

collections.abc.Callable[[float, float], float]

Initialize.

Parameters:
  • alpha (float) – exponential moving average decay factor.

  • r_thresh (float) – threshold on the ratio current_norm / mean_norm.

  • clipping_function (Callable[[float, float], float]) – function to determine clipping behavior.

should_clip(value: float) → bool

Determine whether to clip gradients based on the current value.

Parameters:

value (float) – current gradient norm or value to evaluate.

Returns:

True if gradients should be clipped, False otherwise.

Return type:

bool

get_clip_value(value: float) → float

Calculate the clipping threshold based on current statistics.

Parameters:

value (float) – Current gradient norm or value.

Returns:

The value to clip gradients to.

Return type:

float

update(value: float) → None

Update internal statistics with a new observed value.

Parameters:

value (float) – new gradient norm or value to incorporate.

Return type:

None

set_statistics(mean: float, variance: float = 0.0) → None

Initialize statistics from warmup data.

Parameters:
  • mean (float) – mean value from the warmup period.

  • variance (float) – variance from the warmup period (if applicable).

Return type:

None

reset() → None

Reset all internal statistics to initial state.

Return type:

None
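A minimal sketch of the ratio test described above, in plain Python so it runs without drytorch or PyTorch. The class name `EMARatioSketch`, the EMA update rule and the choice to cap the norm at `r_thresh` times the running mean (mirroring `max_clipping`) are assumptions; the library's own update order and clip value may differ.

```python
from typing import Optional


class EMARatioSketch:
    """Hypothetical EMA ratio criterion; not drytorch's implementation."""

    def __init__(self, alpha: float = 0.98, r_thresh: float = 1.05) -> None:
        self.alpha = alpha
        self.r_thresh = r_thresh
        self.mean: Optional[float] = None

    def set_statistics(self, mean: float, variance: float = 0.0) -> None:
        # Only the mean is tracked, so the warmup variance is ignored.
        self.mean = mean

    def should_clip(self, value: float) -> bool:
        # value / mean > r_thresh, written as a product to avoid dividing by zero.
        return self.mean is not None and value > self.r_thresh * self.mean

    def get_clip_value(self, value: float) -> float:
        # Assumed max-clipping behavior: cap the norm at r_thresh * mean.
        assert self.mean is not None, "statistics not initialized"
        return self.r_thresh * self.mean

    def update(self, value: float) -> None:
        # Exponential moving average: the previous mean keeps weight alpha.
        if self.mean is None:
            self.mean = value
        else:
            self.mean = self.alpha * self.mean + (1 - self.alpha) * value

    def reset(self) -> None:
        self.mean = None


crit = EMARatioSketch(alpha=0.9, r_thresh=1.5)
crit.set_statistics(mean=2.0)
print(crit.should_clip(2.5))     # False: ratio 1.25
print(crit.should_clip(4.0))     # True: ratio 2.0
print(crit.get_clip_value(4.0))  # 3.0
```

Because only the mean is tracked, the criterion reacts to relative spikes: a norm twice the running mean is flagged whatever the absolute scale of the gradients.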

class GradNormClipper(threshold: float = 1)

Bases: ClipOperation

Gradient norm clipping strategy.

threshold

Maximum norm value of the clipped gradients.

Type:

float

Initialize.

Parameters:

threshold (float) – Maximum norm value of the clipped gradients.

__call__(params: Iterable[Parameter]) → None

Clip gradients by norm in-place.

Parameters:

params (Iterable[Parameter]) – model parameters whose gradients are clipped.

Return type:

None
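The following sketch shows global L2-norm clipping, the same idea as PyTorch's `torch.nn.utils.clip_grad_norm_`, on plain Python lists standing in for parameter gradients. Whether `GradNormClipper` computes one global norm (as here) or one norm per parameter is an assumption, and `clip_grad_norm` is a hypothetical helper name.

```python
import math


def clip_grad_norm(grads: list[list[float]], threshold: float = 1.0) -> float:
    """Hypothetical helper: rescale grads in place so their global L2 norm <= threshold."""
    # Global norm over every entry of every gradient.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm > threshold:
        scale = threshold / total_norm
        for grad in grads:
            for i, g in enumerate(grad):
                grad[i] = g * scale
    # Return the norm measured before clipping.
    return total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # global norm 5.0
clip_grad_norm(grads, threshold=1.0)
# grads is now approximately [[0.6, 0.0], [0.0, 0.8]]
```

Rescaling every gradient by a single factor preserves the direction of the overall update, unlike clamping individual values.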

class GradParamNormalizer(*args, **kwargs)

Bases: GradientOpProtocol

Strategy that normalizes each parameter’s gradient to unit norm.

__call__(params: Iterable[Parameter]) → None

Normalize gradients to unit norm in-place.

Parameters:

params (Iterable[Parameter]) – model parameters whose gradients are normalized.

Return type:

None
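A sketch of per-parameter normalization on plain lists. `normalize_per_param` is a hypothetical name, and skipping all-zero gradients (rather than dividing by zero) is an assumption about how such gradients are handled.

```python
import math


def normalize_per_param(grads: list[list[float]], eps: float = 1e-12) -> None:
    """Hypothetical helper: rescale each gradient in place to unit L2 norm."""
    for grad in grads:
        norm = math.sqrt(sum(g * g for g in grad))
        if norm > eps:  # leave (near-)zero gradients untouched
            for i, g in enumerate(grad):
                grad[i] = g / norm


grads = [[3.0, 4.0], [0.0, 2.0], [0.0, 0.0]]
normalize_per_param(grads)
# grads becomes [[0.6, 0.8], [0.0, 1.0], [0.0, 0.0]]
```

Each parameter ends up with the same gradient norm, so only the direction of each parameter's gradient carries information.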

class GradValueClipper(threshold: float = 1)

Bases: ClipOperation

Gradient value clipping strategy.

threshold

Maximum absolute value of the clipped gradients.

Type:

float

Initialize.

Parameters:

threshold (float) – Maximum absolute value of the clipped gradients.

__call__(params: Iterable[Parameter]) → None

Clip gradients by value in-place.

Parameters:

params (Iterable[Parameter]) – model parameters whose gradients are clipped.

Return type:

None
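Value clipping clamps each gradient entry independently, as PyTorch's `torch.nn.utils.clip_grad_value_` does. The sketch below uses plain lists and a hypothetical `clip_grad_value` helper rather than drytorch's code.

```python
def clip_grad_value(grads: list[list[float]], threshold: float = 1.0) -> None:
    """Hypothetical helper: clamp every gradient entry to [-threshold, threshold]."""
    for grad in grads:
        for i, g in enumerate(grad):
            grad[i] = max(-threshold, min(threshold, g))


grads = [[0.5, -3.0], [2.0]]
clip_grad_value(grads, threshold=1.0)
# grads becomes [[0.5, -1.0], [1.0]]
```

Clamping entries individually can change the direction of a gradient: in the example, the first gradient's entries go from a 1:-6 ratio to 1:-2.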

class GradZScoreNormalizer(*args, **kwargs)

Bases: GradientOpProtocol

Gradient normalizing strategy using Z-score normalization.

__call__(params: Iterable[Parameter]) → None

Normalize gradients using Z-score in-place.

Parameters:

params (Iterable[Parameter]) – model parameters whose gradients are normalized.

Return type:

None
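The docstring does not say over which entries the Z-score is computed. The sketch below assumes each parameter's gradient is standardized on its own, using the population standard deviation plus a small `eps` in the denominator; `zscore_per_param` and `eps` are hypothetical.

```python
import statistics


def zscore_per_param(grads: list[list[float]], eps: float = 1e-8) -> None:
    """Hypothetical helper: standardize each gradient in place to zero mean, unit std."""
    for grad in grads:
        mean = statistics.fmean(grad)
        std = statistics.pstdev(grad)  # population standard deviation
        for i, g in enumerate(grad):
            # eps guards against division by zero for constant gradients.
            grad[i] = (g - mean) / (std + eps)


grads = [[1.0, 3.0], [2.0, 2.0, 2.0]]
zscore_per_param(grads)
# first gradient becomes approximately [-1.0, 1.0]; the constant one becomes [0.0, 0.0, 0.0]
```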

class HistClipper(criterion: ClippingCriterion = <drytorch.lib.gradient_ops.ZStatCriterion object>, warmup_clip_strategy: GradientOpProtocol = <drytorch.lib.gradient_ops.GradNormClipper object>, n_warmup_steps: int = 20)

Bases: ClipOperation

Global gradient clipping strategy that uses previous gradient statistics.

The gradients’ norm is renormalized according to a clipping criterion.

criterion

the clipping criterion to determine when and how to clip.

Type:

drytorch.lib.gradient_ops.ClippingCriterion

warmup_clip_strategy

the clipping strategy used during warmup.

Type:

drytorch.core.protocols.GradientOpProtocol

n_warmup_steps

the number of warmup steps to collect initial stats.

Type:

int

Initialize.

Parameters:
  • criterion (ClippingCriterion) – the clipping criterion to determine when and how to clip.

  • warmup_clip_strategy (GradientOpProtocol) – the clipping strategy used during warmup.

  • n_warmup_steps (int) – the number of warmup steps to collect initial stats.

__call__(params: Iterable[Parameter]) → None

Apply global gradient clipping.

Parameters:

params (Iterable[Parameter]) – model parameters to clip.

Return type:

None

Side Effects:

Modifies gradients in-place if clipping is applied.

reset()

Reset the state.
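The following sketch illustrates the control flow suggested by the attributes above: for `n_warmup_steps` steps the global norm is recorded while a fixed norm clip is applied (standing in for the default `GradNormClipper`, threshold 1); the recorded norms then seed running statistics, and later spikes are clipped by a Z-score test. Everything here is an assumption for illustration: the class `HistClipperSketch`, the inline Z-score logic in place of a separate criterion object, max clipping instead of the default reciprocal clipping, and updating the statistics with the clipped norm.

```python
import math
import statistics


class HistClipperSketch:
    """Hypothetical history-based global norm clipper; not drytorch's code."""

    def __init__(self, alpha: float = 0.97, z_thresh: float = 2.5,
                 n_warmup_steps: int = 20, warmup_threshold: float = 1.0) -> None:
        self.alpha = alpha
        self.z_thresh = z_thresh
        self.n_warmup_steps = n_warmup_steps
        self.warmup_threshold = warmup_threshold
        self.warmup_norms: list[float] = []
        self.mean = 0.0
        self.var = 0.0

    def __call__(self, grads: list[list[float]]) -> None:
        norm = math.sqrt(sum(g * g for grad in grads for g in grad))
        if len(self.warmup_norms) < self.n_warmup_steps:
            # Warmup: record the raw norm and apply fixed-threshold norm clipping.
            self.warmup_norms.append(norm)
            self._rescale(grads, norm, min(norm, self.warmup_threshold))
            if len(self.warmup_norms) == self.n_warmup_steps:
                # Warmup complete: seed the running statistics.
                self.mean = statistics.fmean(self.warmup_norms)
                self.var = statistics.pvariance(self.warmup_norms)
            return
        std = math.sqrt(self.var)
        # With zero variance no spike can be detected in this simplified sketch.
        z = (norm - self.mean) / std if std > 0 else 0.0
        if z > self.z_thresh:
            # Spike detected: clip the norm to mean + z_thresh * std.
            target = self.mean + self.z_thresh * std
            self._rescale(grads, norm, target)
            norm = target
        # EMA update of mean and variance with the (possibly clipped) norm.
        self.mean = self.alpha * self.mean + (1 - self.alpha) * norm
        self.var = self.alpha * self.var + (1 - self.alpha) * (norm - self.mean) ** 2

    @staticmethod
    def _rescale(grads: list[list[float]], norm: float, target: float) -> None:
        # Scale all gradients by target / norm when that shrinks them.
        if 0 < target < norm:
            for grad in grads:
                for i, g in enumerate(grad):
                    grad[i] = g * target / norm


clipper = HistClipperSketch(n_warmup_steps=3, warmup_threshold=10.0)
for value in (1.0, 2.0, 3.0):
    clipper([[value]])  # warmup: norms 1, 2, 3 are recorded
spike = [[100.0]]
clipper(spike)  # clipped to about 2.0 + 2.5 * 0.816, i.e. about 4.04
```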

class NoOp

Bases: GradientOpProtocol

Placeholder performing no gradient action.

__call__(params: Iterable[Parameter]) → None

No operation is performed.

Parameters:

params (Iterable[Parameter]) – model parameters; their gradients are left unchanged.

Return type:

None

class ParamHistClipper(criterion: ClippingCriterion = <drytorch.lib.gradient_ops.ZStatCriterion object>, warmup_clip_strategy: GradientOpProtocol = <drytorch.lib.gradient_ops.GradNormClipper object>, n_warmup_steps: int = 20)

Bases: ClipOperation

Gradient clipping strategy that keeps per-parameter statistics.

The gradients’ norm is renormalized according to a clipping criterion.

criterion

the clipping criterion to determine when and how to clip.

Type:

drytorch.lib.gradient_ops.ClippingCriterion

warmup_clip_strategy

the clipping strategy used during warmup.

Type:

drytorch.core.protocols.GradientOpProtocol

n_warmup_steps

the number of warmup steps to collect initial stats.

Type:

int

Initialize.

Parameters:
  • criterion (ClippingCriterion) – the clipping criterion to determine when and how to clip.

  • warmup_clip_strategy (GradientOpProtocol) – the clipping strategy used during warmup.

  • n_warmup_steps (int) – the number of warmup steps to collect initial stats.

__call__(params: Iterable[Parameter]) → None

Apply per-parameter gradient clipping.

Parameters:

params (Iterable[Parameter]) – Model parameters to clip.

Return type:

None

Side Effects:

Modifies gradients in-place if clipping is applied.

reset()

Reset the state.
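A sketch of the per-parameter idea: each gradient gets an independent copy of the criterion, so a spike in one parameter is judged against that parameter's own history. `ParamHistClipperSketch` and the minimal `RatioCriterion` are hypothetical; warmup is omitted for brevity, and keying criteria by the position of each gradient assumes parameters arrive in a stable order.

```python
import copy
import math
from typing import Optional


class RatioCriterion:
    """Hypothetical minimal criterion: clip when a norm exceeds r_thresh * EMA mean."""

    def __init__(self, alpha: float = 0.9, r_thresh: float = 2.0) -> None:
        self.alpha = alpha
        self.r_thresh = r_thresh
        self.mean: Optional[float] = None

    def should_clip(self, value: float) -> bool:
        return self.mean is not None and value > self.r_thresh * self.mean

    def get_clip_value(self, value: float) -> float:
        return self.r_thresh * self.mean

    def update(self, value: float) -> None:
        if self.mean is None:
            self.mean = value
        else:
            self.mean = self.alpha * self.mean + (1 - self.alpha) * value


class ParamHistClipperSketch:
    """Hypothetical per-parameter clipper: one independent criterion per gradient."""

    def __init__(self, criterion: RatioCriterion) -> None:
        self.prototype = criterion
        self.criteria: dict[int, RatioCriterion] = {}

    def __call__(self, grads: list[list[float]]) -> None:
        for idx, grad in enumerate(grads):
            # Each parameter gets its own copy of the criterion, created lazily.
            if idx not in self.criteria:
                self.criteria[idx] = copy.deepcopy(self.prototype)
            crit = self.criteria[idx]
            norm = math.sqrt(sum(g * g for g in grad))
            if crit.should_clip(norm):
                target = crit.get_clip_value(norm)
                for i, g in enumerate(grad):
                    grad[i] = g * target / norm
                norm = target
            crit.update(norm)


clipper = ParamHistClipperSketch(RatioCriterion(alpha=0.9, r_thresh=2.0))
clipper([[1.0], [10.0]])  # first step: statistics are initialized, nothing is clipped
grads = [[5.0], [10.0]]
clipper(grads)
# grads becomes [[2.0], [10.0]]
```

In the example the first gradient (norm 5.0 against a history of 1.0) is clipped, while the second keeps its larger but typical norm of 10.0; a single global statistic would not separate the two cases.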

class StatsCollector(max_samples: int)

Bases: object

Collect samples during warmup to compute their mean and variance.

max_samples

the number of samples required to complete the collection.

Type:

int

active

whether the collector is currently in use.

Type:

bool

Initialize warmup handler.

Parameters:

max_samples (int) – the number of samples required to complete the collection.

__len__() → int

Return the number of collected warmup samples.

Return type:

int

property mean: float

Calculate the mean of the collected warmup samples.

property variance: float

Calculate the variance of the collected warmup samples.

is_complete() → bool

Check if the collection is complete.

Return type:

bool

append(value: float) → None

Add a new datum to the collection.

Parameters:

value (float) – new sample to add.

Return type:

None

reset() → None

Reset the collection.

Return type:

None
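A sketch mirroring the documented `StatsCollector` interface, using the standard `statistics` module. `StatsCollectorSketch` is hypothetical, the `active` attribute is omitted, and the population variance is an assumption (drytorch may use the sample variance).

```python
import statistics


class StatsCollectorSketch:
    """Hypothetical warmup collector mirroring the documented StatsCollector API."""

    def __init__(self, max_samples: int) -> None:
        self.max_samples = max_samples
        self._samples: list[float] = []

    def __len__(self) -> int:
        return len(self._samples)

    @property
    def mean(self) -> float:
        return statistics.fmean(self._samples)

    @property
    def variance(self) -> float:
        # Population variance (divides by n, not n - 1).
        return statistics.pvariance(self._samples)

    def is_complete(self) -> bool:
        return len(self._samples) >= self.max_samples

    def append(self, value: float) -> None:
        self._samples.append(value)

    def reset(self) -> None:
        self._samples.clear()


collector = StatsCollectorSketch(max_samples=3)
for norm in (1.0, 2.0, 3.0):
    collector.append(norm)
print(collector.is_complete(), collector.mean, collector.variance)
# True 2.0 and a variance of about 0.667
```

Once `is_complete()` returns True, the collected `mean` and `variance` are the kind of warmup values that `ClippingCriterion.set_statistics` takes.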

class ZStatCriterion(alpha: float = 0.97, z_thresh: float = 2.5, clipping_function: Callable[[float, float], float] = reciprocal_clipping)

Bases: ClippingCriterion

Clipping criterion based on the Z-statistic.

Tracks both the mean and the variance of the observed values using exponential moving averages. The clipping threshold applies to the Z-score, i.e. the deviation from the running mean measured in running standard deviations. See also https://arxiv.org/pdf/2504.02507.

alpha

exponential moving average decay factor (0 < alpha < 1).

Type:

float

z_thresh

threshold on the absolute Z-score |z_score|.

Type:

float

clipping_function

function to determine clipping behavior.

Type:

collections.abc.Callable[[float, float], float]

Initialize.

Parameters:
  • alpha (float) – exponential moving average decay factor (0 < alpha < 1).

  • z_thresh (float) – threshold for the Z-score.

  • clipping_function (Callable[[float, float], float]) – function to determine clipping behavior.

should_clip(value: float) → bool

Check if the Z-score exceeds the threshold.

Parameters:

value (float) – current gradient norm or value to evaluate.

Return type:

bool

get_clip_value(value: float) → float

Calculate the clipping threshold based on current statistics.

Parameters:

value (float) – Current gradient norm or value.

Returns:

The value to clip gradients to.

Return type:

float

update(value: float) → None

Update internal statistics with a new observed value.

Parameters:

value (float) – new gradient norm or value to incorporate.

Return type:

None

set_statistics(mean: float, variance: float = 0.0) → None

Initialize statistics from warmup data.

Parameters:
  • mean (float) – mean value from the warmup period.

  • variance (float) – variance from the warmup period (if applicable).

Return type:

None

reset() → None

Reset all internal statistics to initial state.

Return type:

None
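A sketch of a Z-statistic criterion following the ZClip formulation in the cited paper: EMA estimates of the mean and variance, a test on z = (value - mean) / std, and reciprocal clipping to mean + (z_thresh**2 / z) * std. `ZStatSketch` is hypothetical; it flags only upward deviations, whereas the `z_thresh` description above mentions the absolute Z-score, and drytorch's exact update order may differ.

```python
import math
from typing import Optional


class ZStatSketch:
    """Hypothetical Z-score criterion with EMA statistics; not drytorch's code."""

    def __init__(self, alpha: float = 0.97, z_thresh: float = 2.5) -> None:
        self.alpha = alpha
        self.z_thresh = z_thresh
        self.mean: Optional[float] = None
        self.var = 0.0

    def set_statistics(self, mean: float, variance: float = 0.0) -> None:
        self.mean, self.var = mean, variance

    def _z(self, value: float) -> float:
        std = math.sqrt(self.var)
        return (value - self.mean) / std if std > 0 else 0.0

    def should_clip(self, value: float) -> bool:
        # Only upward spikes are flagged; small norms are left alone.
        return self.mean is not None and self._z(value) > self.z_thresh

    def get_clip_value(self, value: float) -> float:
        # Reciprocal clipping; assumes should_clip(value) was True, so z > z_thresh > 0.
        z = self._z(value)
        return self.mean + (self.z_thresh ** 2 / z) * math.sqrt(self.var)

    def update(self, value: float) -> None:
        if self.mean is None:
            self.mean = value
            return
        self.mean = self.alpha * self.mean + (1 - self.alpha) * value
        self.var = self.alpha * self.var + (1 - self.alpha) * (value - self.mean) ** 2

    def reset(self) -> None:
        self.mean, self.var = None, 0.0


crit = ZStatSketch(alpha=0.9, z_thresh=2.0)
crit.set_statistics(mean=1.0, variance=0.25)  # std = 0.5
print(crit.should_clip(1.5))     # False: z = 1
print(crit.should_clip(3.0))     # True: z = 4
print(crit.get_clip_value(3.0))  # 1.5
print(crit.get_clip_value(5.0))  # 1.25 (z = 8)
```

The example shows the reciprocal behavior: a spike at z = 8 receives a lower target norm (1.25) than one at z = 4 (1.5).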

max_clipping(zt: float, z_thresh: float) → float

Standard clipping to the threshold value.

Parameters:
  • zt (float) – the Z-statistic or ratio of the current gradient norm.

  • z_thresh (float) – the threshold for the Z-statistic or ratio.

Returns:

The threshold value, used as the renormalization factor.

Return type:

float

mean_clipping(zt: float, z_thresh: float) → float

Clip to the mean value (effectively setting the gradient norm to its running mean).

Parameters:
  • zt (float) – the Z-statistic or ratio of the current gradient norm.

  • z_thresh (float) – the threshold for the Z-statistic or ratio.

Returns:

Renormalization factor of 0 (clips to mean).

Return type:

float

reciprocal_clipping(zt: float, z_thresh: float) → float

Reciprocal clipping as recommended in https://arxiv.org/pdf/2504.02507.

Instead of clipping to the threshold value, reciprocal clipping decreases the norm of the gradient even further as the spike gets larger.

Parameters:
  • zt (float) – the Z-statistic or ratio of the current gradient norm.

  • z_thresh (float) – the threshold for the Z-statistic or ratio.

Returns:

Renormalization factor (between 0 and 1).

Return type:

float
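The three clipping functions can be compared with a small sketch. It assumes the convention of the cited paper, where a spike with statistic zt is clipped to the norm mean + factor * std; under that convention max clipping uses factor z_thresh, mean clipping uses 0, and reciprocal clipping uses z_thresh**2 / zt. The helper names are hypothetical, and drytorch may scale its factors differently (its docstring describes the reciprocal factor as lying between 0 and 1).

```python
from collections.abc import Callable


def max_factor(zt: float, z_thresh: float) -> float:
    # Clip exactly at the threshold, whatever the size of the spike.
    return z_thresh


def mean_factor(zt: float, z_thresh: float) -> float:
    # A factor of 0 sends the norm back to the running mean.
    return 0.0


def reciprocal_factor(zt: float, z_thresh: float) -> float:
    # The factor shrinks as the spike zt grows (zt > z_thresh > 0 assumed).
    return z_thresh ** 2 / zt


def clipped_norm(mean: float, std: float, zt: float, z_thresh: float,
                 factor_fn: Callable[[float, float], float]) -> float:
    """Target norm mean + factor * std, following the ZClip convention."""
    return mean + factor_fn(zt, z_thresh) * std


for fn in (max_factor, mean_factor, reciprocal_factor):
    print(fn.__name__, clipped_norm(1.0, 0.5, zt=5.0, z_thresh=2.5, factor_fn=fn))
# max_factor 2.25
# mean_factor 1.0
# reciprocal_factor 1.625
```

At zt = z_thresh the reciprocal factor equals z_thresh, so reciprocal and max clipping agree at the threshold and diverge as the spike grows.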