Source code for drytorch.lib.objectives

"""Module containing classes to create and combine loss and metrics.

The interface is similar to https://github.com/Lightning-AI/torchmetrics,
with stricter typing and simpler construction. MetricCollection and
CompositionalMetric from torchmetrics change their state; here a functional
approach is preferred.
"""

from __future__ import annotations

import abc
import copy
import operator
import warnings

from collections.abc import Callable, Mapping, Sequence
from typing import Any, Final, Generic, Literal, Self, TypeVar

import torch

from typing_extensions import override

from drytorch.core import exceptions
from drytorch.core import protocols as p
from drytorch.lib import aggregators


# Names exported by ``from drytorch.lib.objectives import *``.
__all__ = [
    'AverageObjective',
    'CompositionalLoss',
    'JoinLossMetrics',
    'JoinMetrics',
    'Loss',
    'LossBase',
    'Metric',
    'MetricCollection',
    'MetricTracker',
    'Objective',
    'compute_metrics',
]


# Type variables for the model outputs and the ground-truth targets consumed
# by the objectives; both are contravariant and bound to the protocol types.
Output = TypeVar('Output', bound=p.OutputType, contravariant=True)
Target = TypeVar('Target', bound=p.TargetType, contravariant=True)
# Short alias used in the annotations of this module.
Tensor = torch.Tensor


class Objective(p.ObjectiveProtocol[Output, Target], metaclass=abc.ABCMeta):
    """Abstract base class for metrics or losses.

    The results of each update are accumulated in the aggregator returned by
    ``_get_aggregator``. Subclasses define how a batch is evaluated
    (``calculate``) and how the accumulated state is reduced (``_compute``).
    """

    def __init__(self) -> None:
        """Initialize the aggregator that accumulates the results."""
        self._aggregator = self._get_aggregator()
        return

    @override
    def compute(self: Self) -> dict[str, Tensor]:
        """Return the aggregated objective value(s).

        A ``ComputedBeforeUpdatedWarning`` is issued when the aggregator is
        falsy (nothing has been accumulated yet). The reduction itself is
        delegated to ``_compute``; any caching of reduced values is left to
        that method and to the aggregator.

        Returns:
            A dictionary of computed metric values.
        """
        if not self._aggregator:
            warnings.warn(
                exceptions.ComputedBeforeUpdatedWarning(self), stacklevel=1
            )

        return self._compute()

    @override
    def update(
        self: Self, outputs: Output, targets: Target
    ) -> dict[str, Tensor]:
        """Update the objective's internal state with new outputs and targets.

        Args:
            outputs: the model outputs.
            targets: the ground truth targets.

        Returns:
            A dictionary of the calculated metric values for the current update.
        """
        results = self.calculate(outputs, targets)
        self._aggregator += results
        return results

    @override
    def reset(self: Self) -> None:
        """Reset the internal state of the instance."""
        self._aggregator.clear()
        return

    @abc.abstractmethod
    def calculate(
        self: Self, outputs: Output, targets: Target
    ) -> dict[str, Tensor]:
        """Method responsible for the calculations.

        Args:
            outputs: model outputs.
            targets: ground truth.

        Returns:
            A dictionary of calculated metric values.
        """

    def copy(self) -> Self:
        """Create a (deep)copy of self."""
        return copy.deepcopy(self, {})

    def merge_state(self: Self, other: Self) -> None:
        """Merge metric states.

        Args:
            other: metric to be merged with.
        """
        self._aggregator += other._aggregator
        return

    def sync(self: Self) -> None:
        """Synchronize metric states across processes."""
        self._aggregator.all_reduce()
        return

    def __deepcopy__(self, memo: dict[int, Any]) -> Self:
        """Deep copy magic method.

        Args:
            memo: dictionary of already copied objects.

        Returns:
            A deep copy of the object.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        # Register the copy before descending into the attributes so that
        # shared or cyclic references to this object resolve to the same copy
        # instead of being duplicated (or recursing without end).
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))

        return result

    @abc.abstractmethod
    def _compute(self: Self) -> dict[str, Tensor]:
        """Compute the objective value(s)."""

    @classmethod
    @abc.abstractmethod
    def _get_aggregator(cls) -> aggregators.AbstractAggregator[Any, Any]:
        """Return the aggregator instance used to accumulate the results."""
class AverageObjective(Objective[Output, Target], metaclass=abc.ABCMeta):
    """Objective whose accumulated results are reduced by a TorchAverager."""

    @override
    def _compute(self: Self) -> dict[str, Tensor]:
        """Reduce the aggregator state into the objective value(s)."""
        reduced = self._aggregator.reduce()
        return reduced

    @classmethod
    def _get_aggregator(cls) -> aggregators.AbstractAggregator[Tensor, Tensor]:
        """Return a new ``TorchAverager`` as the default aggregator."""
        averager = aggregators.TorchAverager()
        return averager
class MetricCollection(AverageObjective[Output, Target]):
    """Several named metric functions evaluated together.

    Attributes:
        named_fn: dictionary of named functions to calculate.
    """

    named_fn: dict[str, Callable[[Output, Target], Tensor]]

    def __init__(
        self,
        **named_fn: Callable[[Output, Target], Tensor],
    ) -> None:
        """Initialize.

        Args:
            **named_fn: dictionary of named functions to calculate.
        """
        super().__init__()
        self.named_fn: Final = named_fn
        return

    @override
    def calculate(self, outputs: Output, targets: Target) -> dict[str, Tensor]:
        """Evaluate every metric function on the given batch.

        Args:
            outputs: the model outputs.
            targets: the ground truth targets.

        Returns:
            A dictionary mapping each metric name to its calculated value.
        """
        values: dict[str, Tensor] = {}
        for metric_name, metric_fn in self.named_fn.items():
            values[metric_name] = metric_fn(outputs, targets)

        return values

    def __or__(
        self, other: MetricCollection[Output, Target]
    ) -> MetricCollection[Output, Target]:
        """Build a new collection holding the functions of both operands.

        Only the metric functions are combined: the accumulated states are
        not. Use the merge_state method separately if that is needed.

        Args:
            other: another MetricCollection object to combine with.

        Returns:
            A new instance containing metrics from both instances.
        """
        # on a name clash the function of the right operand wins
        merged = {**self.named_fn, **other.named_fn}
        return MetricCollection(**merged)
class Metric(MetricCollection[Output, Target]):
    """Subclass for a single metric.

    Attributes:
        fun: the callable that computes the metric value.
        name: identifier for the metric.
        higher_is_better: True if higher values indicate better performance.
    """

    fun: Callable[[Output, Target], Tensor]
    name: str
    higher_is_better: bool | None

    def __init__(
        self,
        fn: Callable[[Output, Target], Tensor],
        /,
        name: str,
        higher_is_better: bool | None = None,
    ) -> None:
        """Initialize.

        Args:
            fn: the callable that computes the metric value.
            name: identifier for the metric.
            higher_is_better: True if higher values indicate better
                performance, False if lower values are better, None if
                unspecified.
        """
        # the collection holds exactly one function, registered under `name`
        single_entry = {name: fn}
        super().__init__(**single_entry)
        self.fun: Final = fn
        self.name: Final = name
        self.higher_is_better: Final = higher_is_better
        return
class LossBase(
    MetricCollection[Output, Target],
    p.LossProtocol[Output, Target],
    metaclass=abc.ABCMeta,
):
    """Collection of metrics, one of which serves as a loss.

    Arithmetic operators build new ``CompositionalLoss`` instances whose
    criterion and formula reflect the operation; the operands are not
    modified.

    Attributes:
        name: identifier for the loss.
        higher_is_better: True if higher values indicate better performance.
        formula: string representation of the loss formula.
        criterion: logic extracting a loss value from computed value.
    """

    name: str
    higher_is_better: bool
    formula: str
    criterion: Callable[[dict[str, Tensor]], Tensor]

    def __init__(
        self,
        criterion: Callable[[dict[str, Tensor]], Tensor],
        name: str,
        higher_is_better: bool = False,
        formula: str = '',
        **named_fn: Callable[[Output, Target], Tensor],
    ) -> None:
        """Initialize.

        Args:
            criterion: logic extracting a loss value from computed value.
            name: identifier for the loss.
            higher_is_better: True if higher values indicate better
                performance, False if lower values are better.
            formula: string representation of the loss formula.
            **named_fn: dictionary of named functions to calculate.
        """
        self.name: Final = name
        self.higher_is_better: Final = higher_is_better
        self.formula: Final = formula
        super().__init__(**named_fn)
        self.criterion: Final = criterion
        return

    @override
    def forward(self, outputs: Output, targets: Target) -> Tensor:
        """Perform a forward pass, update metrics, and compute the loss.

        Args:
            outputs: the model outputs.
            targets: the ground truth targets.

        Returns:
            The computed loss value (mean of the criterion output).
        """
        metrics = self.update(outputs, targets)
        return self.criterion(metrics).mean()

    def watch(self, metric: MetricCollection[Output, Target]) -> None:
        """Include the functions of another collection in its metrics.

        Args:
            metric: the other collection whose functions are added in place.
        """
        self.named_fn.update(metric.named_fn)
        return

    def _combine(
        self,
        other: LossBase[Output, Target] | float,
        operation: Callable[[Tensor, Tensor], Tensor],
        op_fmt: str,
        requires_parentheses: bool = True,
    ) -> CompositionalLoss[Output, Target]:
        """Support operations between losses or a loss and a float.

        Args:
            other: the other loss or float to combine with.
            operation: the callable operation to apply (e.g., operator.add).
            op_fmt: the format string for the combined formula.
            requires_parentheses: whether to wrap sub-formulas in parentheses.

        Returns:
            A new CompositionalLoss representing the combined loss.

        Raises:
            TypeError: if other is neither a loss nor a number.
            ValueError: if other is a loss with the opposite direction.
        """
        if isinstance(other, LossBase):
            named_fn = self.named_fn | other.named_fn
            str_first = self.formula
            str_second = other.formula
            # apply should combine losses that share the same direction
            self._check_same_direction(other)

            def _combined(x: dict[str, Tensor]) -> Tensor:
                return operation(self.criterion(x), other.criterion(x))

        elif isinstance(other, float | int):
            named_fn = self.named_fn
            # the constant is displayed first, e.g. '2 x [loss]'
            str_first = str(other)
            str_second = self.formula

            def _combined(x: dict[str, Tensor]) -> Tensor:
                return operation(self.criterion(x), torch.tensor(other))

        else:
            raise TypeError(f'Unsupported type for operation: {type(other)}')

        if not requires_parentheses:
            str_first = self._remove_outer_parentheses(str_first)
            str_second = self._remove_outer_parentheses(str_second)

        formula = op_fmt.format(str_first, str_second)
        return CompositionalLoss(
            criterion=_combined,
            higher_is_better=self.higher_is_better,
            name='Combined Loss',
            formula=formula,
            **named_fn,
        )

    def __neg__(self) -> CompositionalLoss[Output, Target]:
        """Constructor from an existing template.

        Returns:
            A new CompositionalLoss representing the negated loss, with the
            optimization direction reversed.
        """
        return CompositionalLoss(
            criterion=lambda x: -self.criterion(x),
            higher_is_better=not self.higher_is_better,
            name='Negative ' + self.name,
            formula=f'-{self.formula}',
            **self.named_fn,
        )

    def __add__(
        self,
        other: LossBase[Output, Target] | float,
    ) -> CompositionalLoss[Any, Any]:
        """Constructor from existing templates.

        Args:
            other: the other loss or float to add.

        Returns:
            A new CompositionalLoss representing the sum.
        """
        # adding zero to an already compositional loss is a no-op
        if other == 0 and isinstance(self, CompositionalLoss):
            return self

        return self._combine(other, operator.add, '{} + {}', False)

    def __radd__(self, other: float) -> CompositionalLoss[Any, Any]:
        """Constructor from existing templates.

        Args:
            other: the float to add to the loss.

        Returns:
            A new CompositionalLoss representing the sum.
        """
        return self.__add__(other)

    def __sub__(
        self,
        other: LossBase[Output, Target] | float,
    ) -> CompositionalLoss[Output, Target]:
        """Constructor from existing templates.

        Args:
            other: the other loss or float to subtract.

        Returns:
            A new CompositionalLoss representing the difference.
        """
        # negating a loss flips its direction, which is then checked by the sum
        neg_other = other.__neg__()
        return self.__add__(neg_other)

    def __rsub__(self, other: float) -> CompositionalLoss[Output, Target]:
        """Constructor from existing templates.

        Args:
            other: the float from which to subtract the loss.

        Returns:
            A new CompositionalLoss representing the difference.
        """
        neg_self = self.__neg__()
        return neg_self.__add__(other)

    def __mul__(
        self,
        other: LossBase[Output, Target] | float,
    ) -> CompositionalLoss[Output, Target]:
        """Constructor from existing templates.

        Args:
            other: the other loss or float to multiply by.

        Returns:
            A new CompositionalLoss representing the product.
        """
        # multiplying an already compositional loss by one is a no-op
        if other == 1 and isinstance(self, CompositionalLoss):
            return self

        return self._combine(other, operator.mul, '{} x {}')

    def __rmul__(self, other: float) -> CompositionalLoss[Output, Target]:
        """Constructor from existing templates.

        Args:
            other: the float to multiply the loss by.

        Returns:
            A new CompositionalLoss representing the product.
        """
        return self.__mul__(other)

    def __truediv__(
        self,
        other: LossBase[Output, Target] | float,
    ) -> CompositionalLoss[Output, Target]:
        """Constructor from existing templates.

        Args:
            other: the other loss or float to divide by.

        Returns:
            A new CompositionalLoss representing the quotient.
        """
        if other == 1 and isinstance(self, CompositionalLoss):
            return self

        # division is expressed as a product with the multiplicative inverse
        mul_inv_other = other.__pow__(-1)
        return self.__mul__(mul_inv_other)

    def __rtruediv__(self, other: float) -> CompositionalLoss[Output, Target]:
        """Constructor from existing templates.

        Args:
            other: the float to be divided by the loss.

        Returns:
            A new CompositionalLoss representing the quotient.
        """
        mul_inv_self = self.__pow__(-1)
        return mul_inv_self.__mul__(other)

    def __pow__(self, other: float) -> CompositionalLoss[Output, Target]:
        """Constructor from existing templates.

        Negative exponents reverse the optimization direction.

        Args:
            other: the power to raise the loss to.

        Returns:
            A new CompositionalLoss representing the result.
        """

        def _to_floating_point(x: Tensor) -> Tensor:
            return x if torch.is_floating_point(x) else x.float()

        if other == 1 and isinstance(self, CompositionalLoss):
            return self
        elif other == -1:
            higher_is_better = not self.higher_is_better
            formula = f'1 / {self.formula}'
        elif other >= 0:
            higher_is_better = self.higher_is_better
            formula = f'{self.formula}^{other}'
        else:
            higher_is_better = not self.higher_is_better
            formula = f'1 / {self.formula}^{-other}'

        return CompositionalLoss(
            criterion=lambda x: _to_floating_point(self.criterion(x)) ** other,
            higher_is_better=higher_is_better,
            name='Loss',
            formula=formula,
            **self.named_fn,
        )

    def __repr__(self) -> str:
        """Return the string representation of the LossBase object."""
        return f'{self.__class__.__name__}({self.formula})'

    def _check_same_direction(self, other: LossBase[Output, Target]) -> None:
        """Check if two losses have the same optimization direction.

        Args:
            other: the other LossBase object to compare with.

        Raises:
            ValueError: If the losses have opposite directions for optimization.
        """
        if self.higher_is_better ^ other.higher_is_better:
            msg = 'Losses {} and {} have opposite directions for optimizations.'
            raise ValueError(msg.format(self, other))

        return

    @staticmethod
    def _remove_outer_parentheses(formula: str) -> str:
        """Remove the outer parentheses when they enclose the whole formula.

        The parentheses are stripped only if the opening one at the start is
        closed by the last character: '(a + b)' becomes 'a + b', whereas
        '(a) x (b)' is returned unchanged. Square brackets, which mark a
        metric name, are kept.

        Args:
            formula: the formula string.

        Returns:
            The formula string without the enclosing parentheses.
        """
        if not (formula.startswith('(') and formula.endswith(')')):
            return formula

        depth = 0
        last = len(formula) - 1
        for position, char in enumerate(formula):
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth == 0:
                    # the first parenthesis closes here: it wraps the whole
                    # formula only when this is the final character
                    return formula[1:-1] if position == last else formula

        # unbalanced parentheses: leave the formula untouched
        return formula
class CompositionalLoss(
    LossBase[Output, Target],
):
    """Loss resulting from an operation between other two losses."""

    def __init__(
        self,
        criterion: Callable[[dict[str, Tensor]], Tensor],
        *,
        name: str = 'Loss',
        higher_is_better: bool,
        formula: str = '',
        **named_fn: Callable[[Output, Target], Tensor],
    ) -> None:
        """Initialize.

        Args:
            criterion: function extracting a loss value from metric functions.
            name: identifier for the loss.
            higher_is_better: True if higher values indicate better
                performance, False if lower values are better.
            formula: string representation of the loss formula.
            named_fn: dictionary of named metric functions.
        """
        super().__init__(
            criterion,
            name,
            higher_is_better,
            **named_fn,
            formula=self._format_formula(formula),
        )
        return

    @override
    def calculate(
        self: Self, outputs: Output, targets: Target
    ) -> dict[str, Tensor]:
        """Calculate the loss and all associated metric values.

        Args:
            outputs: the model outputs.
            targets: the ground truth targets.

        Returns:
            A dictionary containing the calculated loss and metric values.
            A metric sharing the name of the loss takes precedence.
        """
        all_metrics = super().calculate(outputs, targets)
        return {self.name: self.criterion(all_metrics)} | all_metrics

    @staticmethod
    def _format_formula(formula: str) -> str:
        """Simplify the formula string and enclose it when needed.

        Double negations and '+ -' sequences are simplified. The result is
        wrapped in parentheses unless a single pair of parentheses or square
        brackets already encloses the whole formula: '[a]' and '(a + b)' are
        kept as they are, whereas '[a] + [b]' becomes '([a] + [b])', so that
        later compositions display the correct precedence.

        Args:
            formula: the formula string.

        Returns:
            The simplified formula string.
        """

        def _encloses_all(opening: str, closing: str) -> bool:
            # True when the first character opens a pair closed by the last one
            if not (formula.startswith(opening) and formula.endswith(closing)):
                return False

            depth = 0
            for position, char in enumerate(formula):
                if char == opening:
                    depth += 1
                elif char == closing:
                    depth -= 1
                    if depth == 0:
                        return position == len(formula) - 1

            return False

        formula = formula.replace('--', '').replace('+ -', '- ')
        if _encloses_all('(', ')') or _encloses_all('[', ']'):
            return formula

        return '(' + formula + ')'
class Loss(CompositionalLoss[Output, Target]):
    """Simple loss built from a single named function."""

    def __init__(
        self,
        fn: Callable[[Output, Target], Tensor],
        /,
        name: str,
        higher_is_better: bool = False,
    ):
        """Initialize.

        Args:
            fn: the callable to calculate the loss.
            name: the name for the loss.
            higher_is_better: the direction for optimization.
        """
        # the criterion simply picks the value calculated under `name`
        pick_own_value = operator.itemgetter(name)
        bracketed_name = f'[{name}]'
        functions = {name: fn}
        super().__init__(
            pick_own_value,
            name=name,
            higher_is_better=higher_is_better,
            formula=bracketed_name,
            **functions,
        )
        return
class JoinMetrics(p.ObjectiveProtocol[Output, Target]):
    """Wrapper delegating every call to two wrapped objectives.

    Preferably, use :py:meth:`MetricCollection.__or__` when both classes are
    :py:meth:`MetricCollection`.
    """

    def __init__(
        self,
        metric_a: Objective[Output, Target],
        metric_b: Objective[Output, Target],
        /,
    ) -> None:
        """Initialize.

        Args:
            metric_a: first objective.
            metric_b: second objective.
        """
        super().__init__()
        self.metric_a = metric_a
        self.metric_b = metric_b
        return

    @override
    def compute(self: Self) -> dict[str, torch.Tensor]:
        """Return the aggregated values of both metrics.

        Returns:
            A dictionary merging the computed values of both metrics; on a
            name clash the value of the second metric is kept.
        """
        computed_a = self.metric_a.compute()
        computed_b = self.metric_b.compute()
        return computed_a | computed_b

    @override
    def reset(self) -> None:
        """Reset the internal state of both metrics."""
        for metric in (self.metric_a, self.metric_b):
            metric.reset()

        return

    def sync(self: Self) -> None:
        """Synchronize metric states across processes."""
        for metric in (self.metric_a, self.metric_b):
            metric.sync()

        return

    @override
    def update(
        self: Self, outputs: Output, targets: Target
    ) -> dict[str, torch.Tensor]:
        """Update both metrics with new outputs and targets.

        Args:
            outputs: the model outputs.
            targets: the ground truth targets.

        Returns:
            A dictionary merging the calculated values of both metrics; on a
            name clash the value of the second metric is kept.
        """
        updated_a = self.metric_a.update(outputs, targets)
        updated_b = self.metric_b.update(outputs, targets)
        return updated_a | updated_b
class JoinLossMetrics(JoinMetrics, p.LossProtocol[Output, Target]):
    """Loss tracked together with an extra objective.

    Preferably, use :py:meth:`LossBase.watch` when the metric is a
    :py:meth:`MetricCollection`.
    """

    def __init__(
        self,
        loss: LossBase[Output, Target],
        objective: Objective[Output, Target],
        /,
    ) -> None:
        """Initialize.

        Args:
            loss: the primary loss.
            objective: the extra metric to track alongside the loss.
        """
        super().__init__(loss, objective)
        self.loss = loss
        # expose the name of the wrapped loss as its own
        self.name = self.loss.name
        return

    @override
    def forward(self, outputs: Output, targets: Target, /) -> torch.Tensor:
        """Delegate the forward pass to the wrapped loss.

        Args:
            outputs: the model outputs.
            targets: the ground truth targets.

        Returns:
            The computed loss.
        """
        loss_value = self.loss.forward(outputs, targets)
        return loss_value
class MetricTracker(Generic[Output, Target]):
    """Handle metric value tracking and improvement detection.

    This class is responsible for storing metric history, determining
    improvements, and managing patience countdown.

    Note: this class can be used to automatically modify the training
    strategy. Therefore, it does not follow the library conventions for a
    tracker.

    Attributes:
        metric_name: the name of the metric to monitor.
        min_delta: the minimum change required to qualify as an improvement.
        patience: number of checks to wait before triggering callback.
        best_is: whether higher or lower values are better.
        filter_fn: function to aggregate recent metric values.
        history: logs of the recorded metrics.
    """

    metric_name: str | None
    best_is: Literal['auto', 'higher', 'lower']
    filter_fn: Callable[[Sequence[float]], float]
    min_delta: float
    patience: int
    history: list[float]
    _patience_countdown: int
    _best_value: float | None

    def __init__(
        self,
        metric_name: str | None = None,
        min_delta: float = 1e-8,
        patience: int = 0,
        best_is: Literal['auto', 'higher', 'lower'] = 'auto',
        filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1),
    ) -> None:
        """Initialize.

        Args:
            metric_name: name of the metric to track.
            min_delta: minimum change required to qualify as an improvement.
            patience: number of checks to wait before triggering callback.
            best_is: whether higher or lower metric values are better.
            filter_fn: function to aggregate recent metric values.

        Raises:
            ValueError: if patience is negative.
        """
        self.metric_name = metric_name
        self.best_is = best_is
        self.filter_fn: Final = filter_fn
        self.min_delta: Final = min_delta
        self._validate_patience(patience)
        self.patience: Final = patience
        self.history: Final = list[float]()
        self._patience_countdown = patience
        self._best_value = None
        return

    @property
    def best_value(self) -> float:
        """Get the best result observed so far.

        Until an improvement is recorded, the first logged value is used.

        Returns:
            the best filtered value (according to the 'best_is' criterion).

        Raises:
            ResultNotAvailableError: if no results have been logged yet.
        """
        if self._best_value is None:
            try:
                self._best_value = self.history[0]
            except IndexError as ie:
                raise exceptions.ResultNotAvailableError() from ie

        return self._best_value

    @best_value.setter
    def best_value(self, value: float) -> None:
        """Set the best result value."""
        self._best_value = value
        return

    @property
    def filtered_value(self) -> float:
        """Get the current value.

        Returns:
            the current value aggregated from recent ones.

        Raises:
            ResultNotAvailableError: if the history does not hold the values
                needed by the filter (e.g., no results logged yet).
        """
        try:
            return self.filter_fn(self.history)
        except IndexError as ie:
            # same translation as in best_value: the documented error instead
            # of the bare IndexError raised by indexing a too short history
            raise exceptions.ResultNotAvailableError() from ie

    def add_value(self, value: float) -> None:
        """Add a new metric value to the history.

        Args:
            value: the metric value to add.
        """
        self.history.append(value)
        return

    def is_better(self, value: float, reference: float) -> bool:
        """Determine if the value is better than a reference value.

        When best_is is in 'auto' mode, the direction is inferred from the
        first two recorded values: it is assumed that the second one is
        better than the first. With fewer than two values the answer is True.

        Args:
            value: the value to compare.
            reference: the reference.

        Returns:
            True if value is a potential improvement, False otherwise.

        Side Effects:
            In 'auto' mode, best_is is set to the inferred direction.
        """
        if value != value:  # Check for NaN
            return False

        if self.best_is == 'auto':
            if len(self.history) < 2:
                return True

            if self.history[0] > self.history[1]:
                self.best_is = 'lower'
            else:
                self.best_is = 'higher'

        if self.best_is == 'lower':
            return reference - self.min_delta > value

        return reference + self.min_delta < value

    def is_improving(self) -> bool:
        """Determine if the model performance is improving.

        Returns:
            True if there has been an improvement, False otherwise.

        Side Effects:
            If there is no improvement, the patience countdown is reduced.
            Otherwise, it is restored to the maximum.
        """
        if len(self.history) <= 1:
            return True

        aggregated_value = self.filtered_value
        if self.is_better(aggregated_value, self.best_value):
            self.best_value = aggregated_value
            self._patience_countdown = self.patience
            return True

        self._patience_countdown -= 1
        return False

    def is_patient(self) -> bool:
        """Check whether the patience countdown is still positive."""
        return self._patience_countdown > 0

    def reset_patience(self) -> None:
        """Reset patience countdown to the maximum."""
        self._patience_countdown = self.patience
        return

    @staticmethod
    def _validate_patience(patience: int) -> None:
        """Raise a ValueError when patience is negative."""
        if patience < 0:
            raise ValueError('Patience must be a non-negative integer.')

        return
def check_device(
    calculator: p.ObjectiveProtocol[Any, Any], device: torch.device
) -> None:
    """Verify that the computed metrics live on the expected device type.

    Only the device type is compared (not the index). Results that are
    neither a mapping nor a tensor are not checked.

    Args:
        calculator: An ObjectiveProtocol instance to check.
        device: The device to check against.

    Raises:
        DeviceMismatchError: if a computed value is on another device type.
    """
    computed = calculator.compute()
    if isinstance(computed, Mapping):
        for metric_name, metric_value in computed.items():
            if metric_value.device.type == device.type:
                continue

            raise exceptions.DeviceMismatchError(
                metric_name, metric_value.device, device
            )

        return

    if not isinstance(computed, Tensor):
        return

    if computed.device.type != device.type:
        # a bare tensor has no metric name: report the calculator class
        class_name = calculator.__class__.__name__
        raise exceptions.DeviceMismatchError(
            class_name, computed.device, device
        )

    return
def compute_metrics(
    calculator: p.ObjectiveProtocol[Any, Any],
) -> Mapping[str, float]:
    """Compute the metrics and convert them to plain Python numbers.

    Args:
        calculator: An ObjectiveProtocol instance from which to compute
            metrics.

    Returns:
        A mapping of metric names to their float values. A bare tensor result
        is stored under the class name of the calculator.

    Raises:
        ComputedMetricsTypeError: if the result is neither a mapping nor a
            tensor.
    """
    result = calculator.compute()
    if isinstance(result, Mapping):
        named_values = {}
        for metric_name, metric_value in result.items():
            named_values[metric_name] = metric_value.item()

        return named_values

    if not isinstance(result, Tensor):
        raise exceptions.ComputedMetricsTypeError(type(result))

    return {calculator.__class__.__name__: result.item()}
def dict_apply(
    dict_fn: dict[str, Callable[[Output, Target], Tensor]],
    outputs: Output,
    targets: Target,
) -> dict[str, Tensor]:
    """Call every named function on the same outputs and targets.

    Args:
        dict_fn: a dictionary of named callables (outputs, targets) -> Tensor.
        outputs: the outputs to apply the tensor callables to.
        targets: the targets to apply the tensor callables to.

    Returns:
        A dictionary mapping each name to the value returned by its callable.
    """
    results: dict[str, Tensor] = {}
    for name, function in dict_fn.items():
        results[name] = function(outputs, targets)

    return results