Source code for drytorch.utils.averages

"""Module containing classes and functions to average samples."""

from __future__ import annotations

import math

from collections.abc import Callable, Sequence


__all__ = [
    'get_moving_average',
    'get_trailing_mean',
]


def get_moving_average(
    decay: float = 0.9,
    mass_coverage: float = 0.99,
) -> Callable[[Sequence[float]], float]:
    """Build an exponentially weighted moving average function.

    The newest sample is weighted ``1 - decay`` and every older sample gets
    ``decay`` times the weight of the sample that follows it. Samples far
    enough in the past that the retained weights already add up to roughly
    ``mass_coverage`` are ignored, and the result is normalized by the sum
    of the weights that were actually used.

    Args:
        decay: the closer to 0, the more the last elements have weight.
        mass_coverage: cumulative weight proportion before tail dropping.

    Returns:
        A function mapping a sequence of floats to its moving average. It
        raises ZeroDivisionError when called with an empty sequence.

    Raises:
        ValueError: if the decay is not strictly between 0 and 1.
        ValueError: if the mass_coverage is not between 1 - decay and 1.
    """
    if not 0 < decay < 1:
        raise ValueError('decay must be between 0 and 1.')

    if not 1 - decay <= mass_coverage <= 1:
        raise ValueError('mass_coverage should be between 1 - decay and 1.')

    # Weight given to the most recent sample (weights are normalized).
    newest_weight = 1 - decay

    # Exclusive bound of the reversed slice: how far back to go before the
    # cumulative weight exceeds the requested coverage. None keeps everything.
    tail_stop: int | None = (
        -int(math.log(1 - mass_coverage, decay)) - 2
        if mass_coverage < 1
        else None
    )
    newest_first = slice(None, tail_stop, -1)

    def _mean(float_list: Sequence[float], /) -> float:
        weighted_sum: float = 0
        weight_sum: float = 0  # approaches one as more samples are used
        current_weight = newest_weight
        for sample in float_list[newest_first]:
            weighted_sum += current_weight * sample
            weight_sum += current_weight
            current_weight *= decay

        return weighted_sum / weight_sum

    # Give the closure a descriptive name that records its configuration.
    _mean.__name__ = (
        f'moving_average(decay={decay}, '
        f'mass_coverage={mass_coverage})'
    )
    return _mean
def get_trailing_mean(window_size: int) -> Callable[[Sequence[float]], float]:
    """Build a function averaging the most recent items of a sequence.

    Args:
        window_size: maximum number of trailing items to aggregate.

    Returns:
        A function mapping a sequence of floats to the mean of its last
        ``window_size`` items, or of all items when the sequence is shorter.
        It raises ZeroDivisionError when called with an empty sequence.

    Raises:
        ValueError: if the window size is not positive.
    """
    if window_size <= 0:
        raise ValueError('window_size must be positive.')

    def _mean(float_list: Sequence[float], /) -> float:
        available = len(float_list)
        # Never average over more items than the sequence holds.
        span = available if available < window_size else window_size
        recent = float_list[-span:]
        return sum(recent) / span

    # Give the closure a descriptive name that records its configuration.
    _mean.__name__ = f'trailing_mean(window_size={window_size})'
    return _mean