"""Module containing classes to create and combine loss and metrics.
The interface is similar to https://github.com/Lightning-AI/torchmetrics,
with stricter typing and simpler construction. MetricCollection and
CompositionalMetric from torchmetrics change their state; here a functional
approach is preferred.
"""
from __future__ import annotations
import abc
import copy
import operator
import warnings
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Final, Generic, Literal, Self, TypeVar
import torch
from typing_extensions import override
from drytorch.core import exceptions
from drytorch.core import protocols as p
from drytorch.lib import aggregators
__all__ = [
'AverageObjective',
'CompositionalLoss',
'JoinLossMetrics',
'JoinMetrics',
'Loss',
'LossBase',
'Metric',
'MetricCollection',
'MetricTracker',
'Objective',
'compute_metrics',
]
Output = TypeVar('Output', bound=p.OutputType, contravariant=True)
Target = TypeVar('Target', bound=p.TargetType, contravariant=True)
Tensor = torch.Tensor
[docs]
class Objective(p.ObjectiveProtocol[Output, Target], metaclass=abc.ABCMeta):
"""Abstract base class for metrics or losses."""
def __init__(self) -> None:
"""Initialize."""
self._aggregator = self._get_aggregator()
return
[docs]
@override
def compute(self: Self) -> dict[str, Tensor]:
"""Return the aggregated objective value(s).
Despite the name, which follows common practice, this method caches
previous computed values and returns them if available.
Returns:
A dictionary of computed metric values.
"""
if not self._aggregator:
warnings.warn(
exceptions.ComputedBeforeUpdatedWarning(self), stacklevel=1
)
return self._compute()
[docs]
@override
def update(
self: Self, outputs: Output, targets: Target
) -> dict[str, Tensor]:
"""Updates the objective's internal state with new outputs and targets.
Args:
outputs: the model outputs.
targets: the ground truth targets.
Returns:
A dictionary of the calculated metric values for the current update.
"""
results = self.calculate(outputs, targets)
self._aggregator += results
return results
[docs]
@override
def reset(self: Self) -> None:
"""Resets the internal state of the instance."""
self._aggregator.clear()
return
[docs]
@abc.abstractmethod
def calculate(
self: Self, outputs: Output, targets: Target
) -> dict[str, Tensor]:
"""Method responsible for the calculations.
Args:
outputs: model outputs.
targets: ground truth.
Returns:
A dictionary of calculated metric values.
"""
[docs]
def copy(self) -> Self:
"""Create a (deep)copy of self."""
return copy.deepcopy(self, {})
[docs]
def merge_state(self: Self, other: Self) -> None:
"""Merge metric states.
Args:
other: metric to be merged with.
"""
self._aggregator += other._aggregator
return
[docs]
def sync(self: Self) -> None:
"""Synchronize metric states across processes."""
self._aggregator.all_reduce()
return
[docs]
def __deepcopy__(self, memo: dict[int, Any]) -> Self:
"""Deep copy magic method.
Args:
memo: dictionary of already copied objects.
Returns:
A deep copy of the object.
"""
cls = self.__class__
result = cls.__new__(cls)
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
return result
@abc.abstractmethod
def _compute(self: Self) -> dict[str, Tensor]:
"""Computes the objective value(s)."""
@classmethod
@abc.abstractmethod
def _get_aggregator(cls) -> aggregators.AbstractAggregator[Any, Any]:
"""Returns the aggregator class."""
[docs]
class AverageObjective(Objective[Output, Target], metaclass=abc.ABCMeta):
"""Class defining the default aggregation."""
@override
def _compute(self: Self) -> dict[str, Tensor]:
"""Computes the objective value(s)."""
return self._aggregator.reduce()
@classmethod
def _get_aggregator(cls) -> aggregators.AbstractAggregator[Tensor, Tensor]:
return aggregators.TorchAverager()
[docs]
class MetricCollection(AverageObjective[Output, Target]):
"""A collection of multiple metrics.
Attributes:
named_fn: dictionary of named functions to calculate.
"""
named_fn: dict[str, Callable[[Output, Target], Tensor]]
def __init__(
self,
**named_fn: Callable[[Output, Target], Tensor],
) -> None:
"""Initialize.
Args:
**named_fn: dictionary of named functions to calculate.
"""
super().__init__()
self.named_fn: Final = named_fn
return
[docs]
@override
def calculate(self, outputs: Output, targets: Target) -> dict[str, Tensor]:
"""Calculates the values for all metrics in the collection.
Args:
outputs: the model outputs.
targets: the ground truth targets.
Returns:
A dictionary of calculated metric values.
"""
return dict_apply(self.named_fn, outputs, targets)
[docs]
def __or__(
self, other: MetricCollection[Output, Target]
) -> MetricCollection[Output, Target]:
"""Constructor using existing MetricCollection objects as templates.
This class does not aggregate the states. If you intend to do this,
use the merge_state method separately.
Args:
other: another MetricCollection object to combine with.
Returns:
A new instance containing metrics from both instances.
"""
named_fn = self.named_fn | other.named_fn
return MetricCollection(**named_fn)
[docs]
class Metric(MetricCollection[Output, Target]):
"""Subclass for a single metr.
Attributes:
fun: the callable that computes the metric value.
name: identifier for the metric.
higher_is_better: True if higher values indicate better performance.
"""
fun: Callable[[Output, Target], Tensor]
name: str
higher_is_better: bool | None
def __init__(
self,
fn: Callable[[Output, Target], Tensor],
/,
name: str,
higher_is_better: bool | None = None,
) -> None:
"""Initialize.
Args:
fn: the callable that computes the metric value.
name: identifier for the metric.
higher_is_better: True if higher values indicate better performance,
False if lower values are better, None if unspecified.
"""
super().__init__(**{name: fn})
self.fun: Final = fn
self.name: Final = name
self.higher_is_better: Final = higher_is_better
return
[docs]
class LossBase(
MetricCollection[Output, Target],
p.LossProtocol[Output, Target],
metaclass=abc.ABCMeta,
):
"""Collection of metrics, one of which serves as a loss.
Attributes:
name: identifier for the loss.
higher_is_better: True if higher values indicate better performance.
formula: string representation of the loss formula.
criterion: logic extracting a loss value from computed value.
"""
name: str
higher_is_better: bool
formula: str
criterion: Callable[[dict[str, Tensor]], Tensor]
def __init__(
self,
criterion: Callable[[dict[str, Tensor]], Tensor],
name: str,
higher_is_better: bool = False,
formula: str = '',
**named_fn: Callable[[Output, Target], Tensor],
) -> None:
"""Initialize.
Args:
criterion: logic extracting a loss value from computed value.
name: identifier for the loss.
higher_is_better: True if higher values indicate better performance,
False if lower values are better.
formula: string representation of the loss formula.
**named_fn: dictionary of named functions to calculate.
"""
self.name: Final = name
self.higher_is_better: Final = higher_is_better
self.formula: Final = formula
super().__init__(**named_fn)
self.criterion: Final = criterion
return
[docs]
@override
def forward(self, outputs: Output, targets: Target) -> Tensor:
"""Performs a forward pass, updates metrics, and computes the loss.
Args:
outputs: the model outputs.
targets: the ground truth targets.
Returns:
The computed loss value.
"""
metrics = self.update(outputs, targets)
return self.criterion(metrics).mean()
[docs]
def watch(self, metric: MetricCollection[Output, Target]) -> None:
"""Include another Objective class in its metrics.
Args:
metric: the other Objective to watch.
"""
self.named_fn.update(metric.named_fn)
return
def _combine(
self,
other: LossBase[Output, Target] | float,
operation: Callable[[Tensor, Tensor], Tensor],
op_fmt: str,
requires_parentheses: bool = True,
) -> CompositionalLoss[Output, Target]:
"""Support operations between losses or a loss and a float.
Args:
other: the other loss or float to combine with.
operation: the callable operation to apply (e.g., operator.add).
op_fmt: the format string for the combined formula.
requires_parentheses: whether to wrap sub-formulas in parentheses.
Returns:
A new CompositionalLoss representing the combined loss.
"""
if isinstance(other, LossBase):
named_fn = self.named_fn | other.named_fn
str_first = self.formula
str_second = other.formula
# apply should combine losses that share the same direction
self._check_same_direction(other)
def _combined(x: dict[str, Tensor]) -> Tensor:
return operation(self.criterion(x), other.criterion(x))
elif isinstance(other, float | int):
named_fn = self.named_fn
str_first = str(other)
str_second = self.formula
def _combined(x: dict[str, Tensor]) -> Tensor:
return operation(self.criterion(x), torch.tensor(other))
else:
raise TypeError(f'Unsupported type for operation: {type(other)}')
if not requires_parentheses:
str_first = self._remove_outer_parentheses(str_first)
str_second = self._remove_outer_parentheses(str_second)
formula = op_fmt.format(str_first, str_second)
return CompositionalLoss(
criterion=_combined,
higher_is_better=self.higher_is_better,
name='Combined Loss',
formula=formula,
**named_fn,
)
[docs]
def __neg__(self) -> CompositionalLoss[Output, Target]:
"""Constructor from an existing template.
Returns:
A new CompositionalLoss representing the negated loss.
"""
return CompositionalLoss(
criterion=lambda x: -self.criterion(x),
higher_is_better=not self.higher_is_better,
name='Negative ' + self.name,
formula=f'-{self.formula}',
**self.named_fn,
)
[docs]
def __add__(
self,
other: LossBase[Output, Target] | float,
) -> CompositionalLoss[Any, Any]:
"""Constructor from exiting templates.
Args:
other: the other loss or float to add.
Returns:
A new CompositionalLoss representing the sum.
"""
if other == 0 and isinstance(self, CompositionalLoss):
return self
return self._combine(other, operator.add, '{} + {}', False)
[docs]
def __radd__(self, other: float) -> CompositionalLoss[Any, Any]:
"""Constructor from exiting templates.
Args:
other: the float to add to the loss.
Returns:
A new CompositionalLoss representing the sum.
"""
return self.__add__(other)
[docs]
def __sub__(
self,
other: LossBase[Output, Target] | float,
) -> CompositionalLoss[Output, Target]:
"""Constructor from exiting templates.
Args:
other: the other loss or float to subtract.
Returns:
A new CompositionalLoss representing the difference.
"""
neg_other = other.__neg__()
return self.__add__(neg_other)
[docs]
def __rsub__(self, other: float) -> CompositionalLoss[Output, Target]:
"""Constructor from exiting templates.
Args:
other: the float from which to subtract the loss.
Returns:
A new CompositionalLoss representing the difference.
"""
neg_self = self.__neg__()
return neg_self.__add__(other)
[docs]
def __mul__(
self,
other: LossBase[Output, Target] | float,
) -> CompositionalLoss[Output, Target]:
"""Constructor from exiting templates.
Args:
other: the other loss or float to multiply by.
Returns:
A new CompositionalLoss representing the product.
"""
if other == 1 and isinstance(self, CompositionalLoss):
return self
return self._combine(other, operator.mul, '{} x {}')
[docs]
def __rmul__(self, other: float) -> CompositionalLoss[Output, Target]:
"""Constructor from exiting templates.
Args:
other: the float to multiply the loss by.
Returns:
A new CompositionalLoss representing the product.
"""
return self.__mul__(other)
[docs]
def __truediv__(
self,
other: LossBase[Output, Target] | float,
) -> CompositionalLoss[Output, Target]:
"""Constructor from exiting templates.
Args:
other: the other loss or float to divide by.
Returns:
A new CompositionalLoss representing the quotient.
"""
if other == 1 and isinstance(self, CompositionalLoss):
return self
mul_inv_other = other.__pow__(-1)
return self.__mul__(mul_inv_other)
[docs]
def __rtruediv__(self, other: float) -> CompositionalLoss[Output, Target]:
"""Constructor from exiting templates.
Args:
other: the float to be divided by the loss.
Returns:
A new CompositionalLoss representing the quotient.
"""
mul_inv_self = self.__pow__(-1)
return mul_inv_self.__mul__(other)
[docs]
def __pow__(self, other: float) -> CompositionalLoss[Output, Target]:
"""Constructor from exiting templates.
Args:
other: the power to raise the loss to.
Returns:
A new CompositionalLoss representing the result.
"""
def _to_floating_point(x: Tensor) -> Tensor:
return x if torch.is_floating_point(x) else x.float()
if other == 1 and isinstance(self, CompositionalLoss):
return self
elif other == -1:
higher_is_better = not self.higher_is_better
formula = f'1 / {self.formula}'
elif other >= 0:
higher_is_better = self.higher_is_better
formula = f'{self.formula}^{other}'
else:
higher_is_better = not self.higher_is_better
formula = f'1 / {self.formula}^{-other}'
return CompositionalLoss(
criterion=lambda x: _to_floating_point(self.criterion(x)) ** other,
higher_is_better=higher_is_better,
name='Loss',
formula=formula,
**self.named_fn,
)
[docs]
def __repr__(self):
"""Returns the string representation of the LossBase object."""
return f'{self.__class__.__name__}({self.formula})'
def _check_same_direction(self, other: LossBase[Output, Target]) -> None:
"""Checks if two losses have the same optimization direction.
Args:
other: the other LossBase object to compare with.
Raises:
ValueError: If the losses have opposite directions for optimization.
"""
if self.higher_is_better ^ other.higher_is_better:
msg = 'Losses {} and {} have opposite directions for optimizations.'
raise ValueError(msg.format(self, other))
return
@staticmethod
def _remove_outer_parentheses(formula: str) -> str:
"""Removes outer parentheses from a formula string if present.
Args:
formula: the formula string.
Returns:
The formula string without outer parentheses.
"""
if formula.startswith('(') and formula.endswith(')'):
return formula[1:-1]
if formula.startswith('[]') and formula.endswith(']'):
return formula[1:-1]
return formula
[docs]
class CompositionalLoss(
LossBase[Output, Target],
):
"""Loss resulting from an operation between other two losses."""
def __init__(
self,
criterion: Callable[[dict[str, Tensor]], Tensor],
*,
name='Loss',
higher_is_better: bool,
formula: str = '',
**named_fn: Callable[[Output, Target], Tensor],
) -> None:
"""Initialize.
Args:
criterion: function extracting a loss value from metric functions.
name: identifier for the loss.
higher_is_better: True if higher values indicate better performance,
False if lower values are better.
formula: string representation of the loss formula.
named_fn: dictionary of named metric functions.
"""
super().__init__(
criterion,
name,
higher_is_better,
**named_fn,
formula=self._format_formula(formula),
)
return
[docs]
@override
def calculate(
self: Self, outputs: Output, targets: Target
) -> dict[str, Tensor]:
"""Calculates the loss and all associated metric values.
Args:
outputs: the model outputs.
targets: the ground truth targets.
Returns:
A dictionary containing the calculated loss and metric values.
"""
all_metrics = super().calculate(outputs, targets)
return {self.name: self.criterion(all_metrics)} | all_metrics
@staticmethod
def _format_formula(formula: str) -> str:
"""Simplifies the formula string by removing redundant characters.
Args:
formula: the formula string.
Returns:
The simplified formula string.
"""
formula = formula.replace('--', '').replace('+ -', '- ')
if formula.startswith('(') and formula.endswith(')'):
return formula
if formula.startswith('[') and formula.endswith(']'):
return formula
return '(' + formula + ')'
[docs]
class Loss(CompositionalLoss[Output, Target]):
"""Subclass for simple losses with a convenient constructor."""
def __init__(
self,
fn: Callable[[Output, Target], Tensor],
/,
name: str,
higher_is_better: bool = False,
):
"""Initialize.
Args:
fn: the callable to calculate the loss.
name: the name for the loss.
higher_is_better: the direction for optimization.
"""
super().__init__(
operator.itemgetter(name),
name=name,
higher_is_better=higher_is_better,
formula=f'[{name}]',
**{name: fn},
)
return
[docs]
class JoinMetrics(p.ObjectiveProtocol[Output, Target]):
"""Wrapper that joins two metrics.
Preferably, use :py:meth:`MetricCollection.__or__` when both classes are
:py:meth:`MetricCollection`.
"""
def __init__(
self,
metric_a: Objective[Output, Target],
metric_b: Objective[Output, Target],
/,
) -> None:
"""Initialize.
Args:
metric_a: first objective.
metric_b: second objective.
"""
super().__init__()
self.metric_a = metric_a
self.metric_b = metric_b
return
[docs]
@override
def compute(self: Self) -> dict[str, torch.Tensor]:
"""Return the aggregated values of both metrics.
Returns:
A dictionary merging the computed values of both metrics.
"""
return self.metric_a.compute() | self.metric_b.compute()
[docs]
@override
def reset(self) -> None:
"""Reset the internal state of both metrics."""
self.metric_a.reset()
self.metric_b.reset()
return
[docs]
def sync(self: Self) -> None:
"""Synchronize metric states across processes."""
self.metric_a.sync()
self.metric_b.sync()
return
[docs]
@override
def update(
self: Self, outputs: Output, targets: Target
) -> dict[str, torch.Tensor]:
"""Update both metrics with new outputs and targets.
Args:
outputs: the model outputs.
targets: the ground truth targets.
Returns:
A dictionary merging the calculated values of both metrics.
"""
metric_a_update = self.metric_a.update(outputs, targets)
metric_b_update = self.metric_b.update(outputs, targets)
return metric_a_update | metric_b_update
[docs]
class JoinLossMetrics(JoinMetrics, p.LossProtocol[Output, Target]):
"""Loss resulting from adding an extra metric to a loss.
Preferably, use :py:meth:`LossBase.watch` when the metric is a
:py:meth:`MetricCollection`.
"""
def __init__(
self,
loss: LossBase[Output, Target],
objective: Objective[Output, Target],
/,
) -> None:
"""Initialize.
Args:
loss: the primary loss.
objective: the extra metric to track alongside the loss.
"""
super().__init__(loss, objective)
self.loss = loss
self.name = self.loss.name
return
[docs]
@override
def forward(self, outputs: Output, targets: Target, /) -> torch.Tensor:
"""Compute and return the loss.
Args:
outputs: the model outputs.
targets: the ground truth targets.
Returns:
The computed loss.
"""
return self.loss.forward(outputs, targets)
[docs]
class MetricTracker(Generic[Output, Target]):
"""Handle metric value tracking and improvement detection.
This class is responsible for storing metric history, determining
improvements, and managing patience countdown.
Note: this class can be used to automatically modify the training strategy.
Therefore, it does not follow the library conventions for a tracker.
Attributes:
metric_name: the name of the metric to monitor.
min_delta: the minimum change required to qualify as an improvement.
patience: number of checks to wait before triggering callback.
best_is: whether higher or lower values are better.
filter_fn: function to aggregate recent metric values.
history: logs of the recorded metrics.
"""
metric_name: str | None
best_is: Literal['auto', 'higher', 'lower']
filter_fn: Callable[[Sequence[float]], float]
min_delta: float
patience: int
history: list[float]
_patience_countdown: int
_best_value: float | None
def __init__(
self,
metric_name: str | None = None,
min_delta: float = 1e-8,
patience: int = 0,
best_is: Literal['auto', 'higher', 'lower'] = 'auto',
filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1),
) -> None:
"""Initialize.
Args:
metric_name: name of the metric to track.
min_delta: minimum change required to qualify as an improvement.
patience: number of checks to wait before triggering callback.
best_is: whether higher or lower metric values are better.
filter_fn: function to aggregate recent metric values.
"""
self.metric_name = metric_name
self.best_is = best_is
self.filter_fn: Final = filter_fn
self.min_delta: Final = min_delta
self._validate_patience(patience)
self.patience: Final = patience
self.history: Final = list[float]()
self._patience_countdown = patience
self._best_value = None
return
@property
def best_value(self) -> float:
"""Get the best result observed so far.
Returns:
the best filtered value (according to the 'best_is' criterion).
Raises:
ResultNotAvailableError: if no results have been logged yet.
"""
if self._best_value is None:
try:
self._best_value = self.history[0]
except IndexError as ie:
raise exceptions.ResultNotAvailableError() from ie
return self._best_value
@best_value.setter
def best_value(self, value: float) -> None:
"""Set the best result value."""
self._best_value = value
return
@property
def filtered_value(self) -> float:
"""Get the current value.
Returns:
the current value aggregated from recent ones.
Raises:
ResultNotAvailableError: if no results have been logged yet.
"""
return self.filter_fn(self.history)
[docs]
def add_value(self, value: float) -> None:
"""Add a new metric value to the history.
Args:
value: the metric value to add.
"""
self.history.append(value)
return
[docs]
def is_better(self, value: float, reference: float) -> bool:
"""Determine if the value is better than a reference value.
When best_is is in 'auto' mode, it is assumed that the given value is
better than the first recorded one.
Args:
value: the value to compare.
reference: the reference.
Returns:
True if value is a potential improvement, False otherwise.
"""
if value != value: # Check for NaN
return False
if self.best_is == 'auto':
if len(self.history) < 2:
return True
if self.history[0] > self.history[1]:
self.best_is = 'lower'
else:
self.best_is = 'higher'
if self.best_is == 'lower':
return reference - self.min_delta > value
return reference + self.min_delta < value
[docs]
def is_improving(self) -> bool:
"""Determine if the model performance is improving.
Returns:
True if there has been an improvement, False otherwise.
Side Effects:
If there is no improvement, the patience countdown is reduced.
Otherwise, it is restored to the maximum.
"""
if len(self.history) <= 1:
return True
aggregated_value = self.filtered_value
if self.is_better(aggregated_value, self.best_value):
self.best_value = aggregated_value
self._patience_countdown = self.patience
return True
self._patience_countdown -= 1
return False
[docs]
def is_patient(self) -> bool:
"""Check whether to be patient."""
return self._patience_countdown > 0
[docs]
def reset_patience(self) -> None:
"""Reset patience countdown to the maximum."""
self._patience_countdown = self.patience
return
@staticmethod
def _validate_patience(patience: int) -> None:
if patience < 0:
raise ValueError('Patience must be a non-negative integer.')
return
def check_device(
calculator: p.ObjectiveProtocol[Any, Any], device: torch.device
) -> None:
"""Check the metrics returned by the calculator are on the given device.
Args:
calculator: An ObjectiveProtocol instance to check.
device: The device to check against.
"""
metrics = calculator.compute()
if isinstance(metrics, Mapping):
for name, value in metrics.items():
if value.device.type != device.type:
raise exceptions.DeviceMismatchError(name, value.device, device)
elif isinstance(metrics, Tensor) and metrics.device.type != device.type:
name = calculator.__class__.__name__
raise exceptions.DeviceMismatchError(name, metrics.device, device)
return
[docs]
def compute_metrics(
calculator: p.ObjectiveProtocol[Any, Any],
) -> Mapping[str, float]:
"""Compute and represent the metrics as a mapping of named values.
Args:
calculator: An ObjectiveProtocol instance from which to compute metrics.
Returns:
A mapping of metric names to their float values.
"""
computed_metrics = calculator.compute()
if isinstance(computed_metrics, Mapping):
return {name: value.item() for name, value in computed_metrics.items()}
if isinstance(computed_metrics, Tensor):
return {calculator.__class__.__name__: computed_metrics.item()}
raise exceptions.ComputedMetricsTypeError(type(computed_metrics))
def dict_apply(
dict_fn: dict[str, Callable[[Output, Target], Tensor]],
outputs: Output,
targets: Target,
) -> dict[str, Tensor]:
"""Apply the given tensor callables to the provided outputs and targets.
Args:
dict_fn: a dictionary of named callables (outputs, targets) -> Tensor.
outputs: the outputs to apply the tensor callables to.
targets: the targets to apply the tensor callables to.
Returns:
A dictionary containing the resulting values.
"""
return {
name: function(outputs, targets) for name, function in dict_fn.items()
}