drytorch.contrib.swa_utils

Utilities for Stochastic Weight Averaging.
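Stochastic Weight Averaging keeps an equal-weight running average of the model parameters collected along the optimization trajectory. The sketch below illustrates that averaging rule on plain Python lists that stand in for parameter tensors. The function `update_average` is hypothetical and not part of drytorch.

```python
from typing import List


def update_average(avg: List[float], new: List[float], n_averaged: int) -> List[float]:
    """Fold a new weight snapshot into an average of `n_averaged` snapshots.

    Uses avg_{n+1} = avg_n + (w - avg_n) / (n + 1), which equals the
    arithmetic mean of all snapshots seen so far.
    """
    if n_averaged == 0:
        # The first snapshot is copied as-is.
        return list(new)
    return [a + (w - a) / (n_averaged + 1) for a, w in zip(avg, new)]


# Hypothetical weight snapshots, e.g. taken at the end of three epochs.
snapshots = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
swa_weights: List[float] = []
for n, weights in enumerate(snapshots):
    swa_weights = update_average(swa_weights, weights, n)

print(swa_weights)  # [3.0, 4.0]
```

The incremental form avoids storing every snapshot: only the current average and the number of snapshots averaged so far are needed.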

Classes

BatchNormUpdater(model[, name])

Update the running statistics of the batch normalization layers.

class BatchNormUpdater(model: ModelProtocol[Input, Output], name: str = '', *, loader: LoaderProtocol[tuple[Input, Target]])

Bases: ModelRunner[Input, Target, Output]

Update the running statistics of the batch normalization layers.

Initialize.

Parameters:
  • model (ModelProtocol[Input, Output]) – the model to run.

  • name (str) – the name of the object, used for logging. Defaults to the class name, followed by a counter when needed.

  • loader (LoaderProtocol[tuple[Input, Target]]) – provides inputs and targets in batches.
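Averaged weights no longer match the running mean and variance that the batch normalization layers accumulated during training, so SWA recomputes those statistics with one extra pass over the data. The pure-Python sketch below illustrates the idea on a toy single-feature layer, assuming the same strategy as PyTorch's `torch.optim.swa_utils.update_bn`: reset the statistics, temporarily set the momentum to `None` so that every batch contributes equally to a cumulative average, then restore the original momentum. `ToyBatchNormStats` and `recompute_stats` are hypothetical names, not drytorch API.

```python
from dataclasses import dataclass
from statistics import fmean, pvariance
from typing import Iterable, List, Optional


@dataclass
class ToyBatchNormStats:
    """Running statistics of a hypothetical single-feature batch normalization layer."""

    momentum: Optional[float] = 0.1
    running_mean: float = 0.0
    running_var: float = 1.0
    num_batches_tracked: int = 0

    def update(self, batch: List[float]) -> None:
        """Fold one batch into the running statistics, as a training-mode forward pass would."""
        batch_mean = fmean(batch)
        # Population variance is a simplification: PyTorch uses the unbiased estimate here.
        batch_var = pvariance(batch)
        self.num_batches_tracked += 1
        # momentum=None means a cumulative (equal-weight) average over all batches seen.
        if self.momentum is None:
            factor = 1.0 / self.num_batches_tracked
        else:
            factor = self.momentum
        self.running_mean += factor * (batch_mean - self.running_mean)
        self.running_var += factor * (batch_var - self.running_var)


def recompute_stats(layer: ToyBatchNormStats, batches: Iterable[List[float]]) -> None:
    """Reset the statistics, recompute them in one pass, then restore the momentum."""
    saved_momentum = layer.momentum
    layer.running_mean, layer.running_var, layer.num_batches_tracked = 0.0, 1.0, 0
    layer.momentum = None
    for batch in batches:
        layer.update(batch)
    layer.momentum = saved_momentum


# Stale statistics left over from training with the pre-averaging weights.
layer = ToyBatchNormStats(momentum=0.1, running_mean=10.0, running_var=9.0)
recompute_stats(layer, [[1.0, 3.0], [5.0, 7.0]])
print(layer.running_mean, layer.running_var)  # 4.0 1.0
```

Switching to a cumulative average means the recomputed statistics do not depend on the order of the batches or on a momentum value tuned for training.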

__call__(store_outputs: bool = False) → None

Run a single pass over the dataset.

Parameters:
  • store_outputs (bool) – whether to store the model outputs computed during the pass. Defaults to False.

Return type:

None
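The sketch below illustrates what a single pass with a `store_outputs` flag could look like. It assumes, without confirmation from the source, that the outputs are kept on an attribute and that the targets are not used. `ToyRunner`, its `outputs` attribute, and `double` are hypothetical and do not reflect the internals of `ModelRunner`.

```python
from typing import Callable, Iterable, List, Tuple


class ToyRunner:
    """Hypothetical stand-in that makes one pass over a loader of (input, target) pairs."""

    def __init__(self, model: Callable[[float], float],
                 loader: Iterable[Tuple[float, float]]) -> None:
        self.model = model
        self.loader = loader
        self.outputs: List[float] = []

    def __call__(self, store_outputs: bool = False) -> None:
        self.outputs.clear()  # forget outputs from any previous pass
        for inputs, _target in self.loader:  # targets are unused in this sketch
            output = self.model(inputs)  # forward pass; stateful layers update here
            if store_outputs:
                self.outputs.append(output)


seen_inputs: List[float] = []


def double(x: float) -> float:
    seen_inputs.append(x)  # record each forward call
    return 2.0 * x


runner = ToyRunner(double, [(1.0, 0.0), (2.0, 0.0)])
runner(store_outputs=True)
print(runner.outputs)  # [2.0, 4.0]
print(seen_inputs)  # [1.0, 2.0]
```

Leaving `store_outputs` at its default avoids keeping every output in memory, which matters when the pass exists only for its side effects, such as refreshing batch normalization statistics.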