Source code for drytorch.utils.average

"""Module containing classes and functions to average samples."""

from __future__ import annotations

import abc
import collections
import copy
import math

from collections.abc import Callable, KeysView, Mapping, Sequence
from typing import Any, Final, Generic, Self, TypeVar

import torch

from torch import distributed as dist
from typing_extensions import override


__all__ = [
    'TorchAverager',
    'get_moving_average',
    'get_trailing_mean',
]

_T = TypeVar('_T', torch.Tensor, float)


class AbstractAverager(Generic[_T], metaclass=abc.ABCMeta):
    """Average tensor values from dict-like objects.

    It registers sample size to calculate the precise sample average.

    Attributes:
        aggregate: a dictionary with the aggregated values.
        counts: a dictionary with the count of the total elements.
    """

    __slots__: Final = ('_cached_reduce', 'aggregate', 'counts')

    aggregate: dict[str, _T]
    counts: collections.defaultdict[str, int]
    _cached_reduce: dict[str, _T]

    def __init__(self, **kwargs: _T):
        """Initialize.

        Args:
            kwargs: named values to average.
        """
        self.aggregate = {}
        self.counts = collections.defaultdict(int)
        self.__iadd__(kwargs)
        self._cached_reduce = {}

    def __add__(self, other: AbstractAverager[_T] | Mapping[str, _T]) -> Self:
        """Join current data with data from another Averager.

        Args:
            other: the other Averager.
        """
        if isinstance(other, Mapping):
            other = self.__class__(**other)

        out = copy.deepcopy(self)
        out += other
        return out

    def __bool__(self) -> bool:
        """Return True if data is present."""
        return bool(self.aggregate)

    def __deepcopy__(self, memo: dict[int, Any] | None) -> Self:
        """Deep copy magic method.

        Args:
            memo: Dictionary of already copied objects.

        Returns:
            A deep copy of the object.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        result.aggregate = copy.deepcopy(self.aggregate)
        result.counts = copy.copy(self.counts)
        result._cached_reduce = {}
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))

        return result

    @override
    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return self.aggregate == other.aggregate and self.counts == other.counts

    def __iadd__(self, other: AbstractAverager[_T] | Mapping[str, _T]) -> Self:
        """Merge current data with data from another Averager.

        Args:
            other: the other Averager.
        """
        if isinstance(other, AbstractAverager):
            other_aggregate = other.aggregate
            other_counts = other.counts
        else:
            other_aggregate = {}
            other_counts = collections.defaultdict[str, int]()
            for key, value in other.items():
                other_aggregate[key] = self._aggregate(value)
                other_counts[key] = self._count(value)

        if self.aggregate:  # fail if new elements are added after the start
            for key, value in other_aggregate.items():
                self.aggregate[key] += value
                self.counts[key] += other_counts[key]
        else:
            self.aggregate.update(other_aggregate)
            self.counts.update(other_counts)

        self._cached_reduce = {}
        return self

    @override
    def __repr__(self) -> str:
        return self.__class__.__name__ + f'(counts={self.counts})'

    def clear(self) -> None:
        """Clear data contained in the class."""
        self._cached_reduce.clear()
        self.aggregate.clear()
        self.counts.clear()
        return

    def keys(self) -> KeysView[str]:
        """Calculate the count of a value."""
        return self.aggregate.keys()

    def reduce(self) -> dict[str, _T]:
        """Return the averaged values."""
        if not self._cached_reduce:
            self._cached_reduce = {
                key: self._reduce(value, self.counts[key])
                for key, value in self.aggregate.items()
            }

        return self._cached_reduce

    def all_reduce(self) -> dict[str, _T]:
        """Synchronize the averaged values across processes."""
        self._cached_reduce.clear()  # make sure to recalculate all the values
        return self.reduce()

    @staticmethod
    def _reduce(aggregated: _T, count: int) -> _T:
        return aggregated / count

    @staticmethod
    @abc.abstractmethod
    def _aggregate(value: _T) -> _T: ...

    @staticmethod
    @abc.abstractmethod
    def _count(value: _T) -> int: ...


class Averager(AbstractAverager[float]):
    """Subclass of Aggregator with an implementation for float values."""

    @staticmethod
    @override
    def _aggregate(value: float) -> float:
        return value

    @staticmethod
    @override
    def _count(value: float) -> int:
        return 1


[docs] class TorchAverager(AbstractAverager[torch.Tensor]): """Subclass of Aggregator with an implementation for torch.Tensor.""" @staticmethod @override def _aggregate(value: torch.Tensor) -> torch.Tensor: return value.detach().sum() @staticmethod @override def _count(value: torch.Tensor) -> int: return value.numel()
[docs] @override def all_reduce(self) -> dict[str, torch.Tensor]: """Synchronize the values across processes.""" if dist.is_available() and torch.distributed.is_initialized(): for key, value in self.aggregate.items(): torch.distributed.all_reduce( value, op=torch.distributed.ReduceOp.SUM ) count_tensor = torch.tensor( self.counts[key], device=value.device, dtype=torch.long ) torch.distributed.all_reduce( count_tensor, op=torch.distributed.ReduceOp.SUM ) self.counts[key] = int(count_tensor.item()) return super().all_reduce()
[docs] def get_moving_average( decay: float = 0.9, mass_coverage: float = 0.99, ) -> Callable[[Sequence[float]], float]: """Return a moving average by specifying the decay. Args: decay: the closer to 0, the more the last elements have weight. mass_coverage: cumulative weight proportion before tail dropping. Returns: The moving average function. Raises: ValueError: if the decay is not between 0 and 1. ValueError: if the mass_coverage is not between 1 - decay and 1. """ if not 0 < decay < 1: raise ValueError('decay must be between 0 and 1.') if not 1 - decay <= mass_coverage <= 1: raise ValueError('mass_coverage should be between 1 - decay and 1.') # how far back to go back before the weight drops below the threshold if mass_coverage < 1: stop = -int(math.log(1 - mass_coverage, decay)) - 2 else: stop = None def _mean(float_list: Sequence[float], /) -> float: total: float = 0 total_weights: float = 0 # should get close to one weight = 1 - decay # weights are normalized for elem in float_list[:stop:-1]: total += weight * elem total_weights += weight weight *= decay return total / total_weights repr_mean = f'moving_average(decay={decay}, mass_coverage={mass_coverage})' _mean.__name__ = repr_mean return _mean
[docs] def get_trailing_mean(window_size: int) -> Callable[[Sequence[float]], float]: """Return a trailing average by specifying window size. Args: window_size: number of items to aggregate. Returns: The windowed average function. Raises: ValueError if the window size is negative. """ if window_size <= 0: raise ValueError('window_size must be positive.') def _mean(float_list: Sequence[float], /) -> float: clipped_window = min(window_size, len(float_list)) return sum(float_list[-clipped_window:]) / clipped_window repr_mean = f'trailing_mean(window_size={window_size})' _mean.__name__ = repr_mean return _mean