drytorch.contrib.swa_utils
Utilities for Stochastic Weight Averaging.
Classes
BatchNormUpdater: Update the momenta in the batch normalization layers.
- class BatchNormUpdater(model: ModelProtocol[Input, Output], name: str = '', *, loader: LoaderProtocol[tuple[Input, Target]])
Bases: ModelRunner[Input, Target, Output]
Update the momenta in the batch normalization layers.
Initialize.
- Parameters:
model (ModelProtocol[Input, Output]) – the model to run.
name (str) – the object's name, used for logging. Defaults to the class name, plus a counter where applicable.
loader (LoaderProtocol[tuple[Input, Target]]) – provides inputs and targets in batches.
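Why batch normalization needs this step: once Stochastic Weight Averaging has averaged the weights, the running mean and variance stored in each batch normalization layer were collected for earlier weights and no longer match the averaged model, so they are recomputed with a pass over the data. The sketch below illustrates that idea in pure Python for a single-feature layer. It assumes the scheme used by PyTorch's `torch.optim.swa_utils.update_bn` (reset the running statistics, then accumulate a cumulative average, which is what PyTorch batch norm does when `momentum=None`). It is not drytorch's implementation, and `ToyBatchNorm`, `track`, and `recompute_running_stats` are hypothetical names. In a real network each layer would see the activations produced by the averaged weights; here the batches stand in for those activations.

```python
from dataclasses import dataclass
from statistics import fmean, variance


@dataclass
class ToyBatchNorm:
    """Hypothetical single-feature batch-norm layer (illustration only)."""

    running_mean: float = 0.0
    running_var: float = 1.0
    num_batches_tracked: int = 0

    def reset_running_stats(self) -> None:
        # Discard statistics collected for the pre-averaging weights.
        self.running_mean = 0.0
        self.running_var = 1.0
        self.num_batches_tracked = 0

    def track(self, batch: list[float]) -> None:
        # Cumulative moving average (PyTorch's behavior with momentum=None):
        # the k-th batch is weighted by 1/k, so the result is a plain average
        # of the per-batch statistics.
        self.num_batches_tracked += 1
        factor = 1.0 / self.num_batches_tracked
        batch_mean = fmean(batch)
        # Unbiased batch variance; requires at least two values per batch.
        batch_var = variance(batch)
        self.running_mean += factor * (batch_mean - self.running_mean)
        self.running_var += factor * (batch_var - self.running_var)


def recompute_running_stats(layer: ToyBatchNorm, loader: list[list[float]]) -> None:
    """Reset the layer, then make one pass over the data to refill its statistics."""
    layer.reset_running_stats()
    for batch in loader:
        layer.track(batch)


# Stale statistics left over from before the weights were averaged.
layer = ToyBatchNorm(running_mean=5.0, running_var=9.0, num_batches_tracked=100)
recompute_running_stats(layer, [[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]])
print(layer.running_mean, layer.running_var)  # 3.0 2.5
```

The reset matters: without it, the 100 stale batches would dominate the cumulative average and the averaged model would keep statistics from its previous weights.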