"""Module containing generic accumulator-based aggregators."""
from __future__ import annotations
import abc
import copy
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Final, Generic, Self, TypeVar
import torch
from torch import distributed as dist
from typing_extensions import override
# Generic parameters shared by accumulators and aggregators:
# _T is the raw input value type, _R is the reduced output type.
_T, _R = TypeVar('_T'), TypeVar('_R')
class AbstractAccumulator(Generic[_T, _R], abc.ABC):
    """Stateful aggregation container.

    An accumulator starts from one raw value (``from_value``), absorbs
    other accumulators of the same kind (``merge``), and produces a final
    result on demand (``reduce``).

    Type parameters:
        _T: type of the raw values fed in.
        _R: type of the reduced result.
    """

    @classmethod
    @abc.abstractmethod
    def from_value(cls, value: _T) -> Self:
        """Build a new accumulator that holds the single raw ``value``."""

    @abc.abstractmethod
    def merge(self, other: Self) -> None:
        """Fold the state of ``other`` into this accumulator (returns None)."""

    @abc.abstractmethod
    def reduce(self) -> _R:
        """Compute the aggregated result from the current state."""

    @abc.abstractmethod
    def sync(self) -> None:
        """Synchronize state across distributed processes.

        Implementations may be a no-op when there is nothing to share.
        """
class AbstractAggregator(Generic[_T, _R], metaclass=abc.ABCMeta):
    """Aggregate named values using accumulator objects.

    Every key maps to an accumulator built by the subclass-provided
    ``accumulator_cls``. Results of ``reduce`` are cached; the cache is
    cleared by ``__iadd__``, ``clear`` and ``all_reduce``. Mutating
    ``accumulators`` directly bypasses that invalidation.
    """

    __slots__: Final = ('_cached_reduce', 'accumulators')
    accumulator_cls: type[AbstractAccumulator[_T, _R]]
    accumulators: dict[str, AbstractAccumulator[_T, _R]]
    _cached_reduce: dict[str, _R]

    def __init__(self, **kwargs: _T) -> None:
        """Initialize.

        Args:
            kwargs: named values to aggregate; each one is wrapped with
                ``accumulator_cls.from_value``.
        """
        self.accumulators = {
            key: self.accumulator_cls.from_value(value)
            for key, value in kwargs.items()
        }
        self._cached_reduce = {}

    def __add__(self, other: Self | Mapping[str, _T]) -> Self:
        """Return a new aggregator containing the data of both operands.

        ``self`` is deep-copied before merging, so neither operand changes.
        """
        result = copy.deepcopy(self)
        result += other
        return result

    def __bool__(self) -> bool:
        """Return True if any values are stored."""
        return bool(self.accumulators)

    def __iadd__(self, other: Self | Mapping[str, _T]) -> Self:
        """Merge another aggregator or a mapping of raw values in place.

        Existing keys are merged into their accumulator; new keys receive
        a deep copy so the state of ``other`` is never aliased.
        """
        if isinstance(other, Mapping):
            other = self.__class__(**other)
        for key, acc in other.accumulators.items():
            if key in self.accumulators:
                self.accumulators[key].merge(acc)
            else:
                self.accumulators[key] = copy.deepcopy(acc)
        self._cached_reduce.clear()
        return self

    def __repr__(self) -> str:
        """Return representation of the aggregator."""
        return (
            f'{self.__class__.__name__}(keys={list(self.accumulators.keys())})'
        )

    def clear(self) -> None:
        """Remove all accumulated data and the cached reduction."""
        self.accumulators.clear()
        self._cached_reduce.clear()

    def keys(self) -> list[str]:
        """Return stored metric names in insertion order."""
        return list(self.accumulators)

    def reduce(self) -> dict[str, _R]:
        """Return reduced values for all metrics.

        The reduction is cached. A shallow copy is returned so that a
        caller mutating the returned dict cannot corrupt the cache.
        """
        if not self._cached_reduce:
            self._cached_reduce = {
                key: acc.reduce() for key, acc in self.accumulators.items()
            }
        return dict(self._cached_reduce)

    def all_reduce(self) -> dict[str, _R]:
        """Synchronize accumulators across processes and reduce.

        ``sync`` may issue collective operations. Every rank must issue
        collectives in the same order, and dict insertion order can differ
        between ranks, so accumulators are synced in sorted key order. All
        ranks must hold the same set of keys, otherwise the collectives
        cannot match up.

        NOTE(review): accumulators whose ``sync`` overwrites local state with
        the global sum will re-sum on repeated calls; call once per
        accumulation window.
        """
        for key in sorted(self.accumulators):
            self.accumulators[key].sync()
        self._cached_reduce.clear()
        return self.reduce()
@dataclass(slots=True)
class MeanAccumulator(AbstractAccumulator[float, float]):
    """Running arithmetic mean over plain Python numbers.

    Attributes:
        total: running sum of every value seen so far.
        count: how many values contributed to ``total``.
    """

    total: float
    count: int

    @classmethod
    @override
    def from_value(cls, value: float) -> MeanAccumulator:
        """Wrap one value: its sum is the value itself and its count is 1."""
        return cls(value, 1)

    @override
    def merge(self, other: Self) -> None:
        """Add ``other``'s running sum and count to this accumulator."""
        self.total += other.total
        self.count += other.count

    @override
    def reduce(self) -> float:
        """Return the mean ``total / count`` of the values seen so far."""
        mean = self.total / self.count
        return mean

    @override
    def sync(self) -> None:
        """No-op: this accumulator performs no distributed synchronization."""
@dataclass(slots=True)
class TorchMeanAccumulator(AbstractAccumulator[torch.Tensor, torch.Tensor]):
    """Accumulator computing arithmetic mean for tensors.

    Attributes:
        total: sum of all elements seen so far.
        count: number of elements summed into ``total``.
    """

    total: torch.Tensor
    count: int

    @classmethod
    @override
    def from_value(cls, value: torch.Tensor) -> TorchMeanAccumulator:
        """Start from the element sum and element count of ``value``.

        ``value`` is detached, so accumulation never extends its autograd
        graph.
        """
        return cls(
            total=value.detach().sum(),
            count=value.numel(),
        )

    @override
    def merge(self, other: Self) -> None:
        """Add the state of ``other`` into this accumulator.

        Addition is out of place. This allows dtype promotion (an integer
        total merged with a float total becomes float, where in-place ``+=``
        would raise). It also never mutates a tensor the caller supplied as
        ``total``.
        """
        self.total = self.total + other.total
        self.count += other.count

    @override
    def reduce(self) -> torch.Tensor:
        """Return the mean ``total / count`` as a tensor."""
        return self.total / self.count

    @override
    def sync(self) -> None:
        """Sum ``total`` and ``count`` over all processes.

        Does nothing unless torch.distributed is available and initialized.
        Otherwise ``total`` is all-reduced in place and ``count`` is
        replaced by the global count. This is a collective call, so every
        rank must call it.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return
        dist.all_reduce(self.total, op=dist.ReduceOp.SUM)
        # The count travels as a tensor on the same device as ``total`` so
        # the same backend can reduce it.
        count_tensor = torch.tensor(
            self.count,
            device=self.total.device,
            dtype=torch.long,
        )
        dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM)
        self.count = int(count_tensor.item())
class Averager(AbstractAggregator[float, float]):
    """Aggregator that averages plain float values per key.

    Each key is backed by a ``MeanAccumulator``, so ``reduce()`` yields the
    arithmetic mean of the values added under that key.
    """

    accumulator_cls = MeanAccumulator
class TorchAverager(AbstractAggregator[torch.Tensor, torch.Tensor]):
    """Aggregator that averages tensor values per key.

    Each key is backed by a ``TorchMeanAccumulator``: the mean is taken over
    every element of every tensor added under that key. ``all_reduce()`` can
    combine these statistics across processes through ``torch.distributed``.
    """

    accumulator_cls = TorchMeanAccumulator