Metrics and Losses


DRYTorch helps you standardize and document your model’s metrics and losses.

Design

The modular design extends to metrics and losses. DRYTorch provides a common interface for both, allowing you to easily switch between different libraries.

Terminology

An objective is a criterion for model performance evaluation. We distinguish between two types:

  • Metric: Assesses model performance as a proxy for the overall goal.

  • Loss: A differentiable objective that is minimized during training to optimize the model parameters, which in turn should improve the metrics.

DRYTorch allows using losses as metrics, but not vice versa.

Protocols

DRYTorch defines an ObjectiveProtocol, used by classes that implement the validation and testing of a model, and a LossProtocol, used for its training.
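As a rough illustration of how protocol-based interfaces work, here is a minimal sketch using only the standard library. The protocol names, method names (update, compute, reset, forward), and example classes below are all hypothetical stand-ins and may not match the actual DRYTorch definitions. The sketch shows why a class from another library can satisfy a protocol without inheriting from it, and why a loss can stand in for a metric but not the other way around.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SketchObjectiveProtocol(Protocol):
    """Hypothetical stand-in, not the DRYTorch definition."""

    def update(self, outputs: float, targets: float) -> None: ...

    def compute(self) -> float: ...

    def reset(self) -> None: ...


@runtime_checkable
class SketchLossProtocol(SketchObjectiveProtocol, Protocol):
    """Assumed extension: a loss also returns a value to backpropagate."""

    def forward(self, outputs: float, targets: float) -> float: ...


class RunningAbsError:
    """Third-party-style metric that inherits from neither protocol."""

    def __init__(self) -> None:
        self._total = 0.0
        self._count = 0

    def update(self, outputs: float, targets: float) -> None:
        self._total += abs(outputs - targets)
        self._count += 1

    def compute(self) -> float:
        return self._total / self._count

    def reset(self) -> None:
        self._total, self._count = 0.0, 0


class AbsErrorLoss(RunningAbsError):
    """Same metric plus a forward method, so it can also act as a loss."""

    def forward(self, outputs: float, targets: float) -> float:
        self.update(outputs, targets)
        return abs(outputs - targets)


metric, loss = RunningAbsError(), AbsErrorLoss()
# isinstance on a runtime-checkable protocol only checks that the methods exist.
print(isinstance(metric, SketchObjectiveProtocol))  # True
print(isinstance(metric, SketchLossProtocol))       # False: no forward method
print(isinstance(loss, SketchObjectiveProtocol))    # True: a loss is usable as a metric
```

Structural checks like these are what make it possible to accept classes from other libraries unchanged, provided they expose the expected methods.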

! uv pip install drytorch

Compatibility with Existing Libraries

DRYTorch does not re-implement common metrics or losses. Instead, it defines protocols to ensure full compatibility with classes from popular existing libraries.

For Validation and Testing

The ObjectiveProtocol is compatible with Metric classes from TorchMetrics and TorchEval.

You can use instances of these third-party metric classes directly when defining a DRYTorch validation or test step.

For Training

The LossProtocol is designed to accept any class that meets its requirements, including some metrics. You can therefore use differentiable metrics from libraries like TorchMetrics directly when building a DRYTorch training class.

TorchMetrics also offers a CompositionalMetric, which supports algebraic operations and inspired part of DRYTorch’s own implementation. To make it compatible with the framework and improve its documentation, you can wrap it with from_torchmetrics.

import torch
import torchmetrics

from torcheval import metrics as eval_metrics

from drytorch.core import protocols as p


tensor_a = torch.ones(1, 1, dtype=torch.float)
tensor_b = 3 * torch.ones(1, 1, dtype=torch.float)
torch_metric = torchmetrics.MeanSquaredError()
eval_metric = eval_metrics.MeanSquaredError()


def is_valid_objective(
    metric: p.ObjectiveProtocol[torch.Tensor, torch.Tensor],
) -> bool:
    """Test metric follows the Objective protocol."""
    return isinstance(metric, p.ObjectiveProtocol)


torch_metric.update(tensor_a, tensor_b)
eval_metric.update(tensor_a, tensor_b)

if not torch.isclose(torch_metric.compute(), eval_metric.compute()):
    raise AssertionError('Metrics values should match.')

if not (is_valid_objective(eval_metric) and is_valid_objective(torch_metric)):
    raise AssertionError('These objects should follow the ObjectiveProtocol.')


def is_valid_loss(
    metric: p.LossProtocol[torch.Tensor, torch.Tensor],
) -> bool:
    """Test metric follows the Loss protocol."""
    return isinstance(metric, p.LossProtocol)


if not is_valid_loss(torch_metric):
    raise AssertionError('This object should also follow the LossProtocol.')


from drytorch.contrib.torchmetrics import from_torchmetrics


new_metric = 1 + torch_metric
imported_metric = from_torchmetrics(new_metric)
imported_metric.update(tensor_a, tensor_b)
expected_metrics_from_torchmetrics = {
    'Combined Loss': torch.tensor(5.0),
    'MeanSquaredError': torch.tensor(4.0),
}
if not imported_metric.compute() == expected_metrics_from_torchmetrics:
    raise AssertionError('Metrics values should be as expected.')

DRYTorch Implementation

DRYTorch objective classes act as wrappers around user-defined metric and loss callables.

These callables must accept model outputs and targets as arguments and return either a scalar PyTorch Tensor (a value already aggregated over the mini-batch) or a vector with one value per sample (recommended for more precise averaging across batches of varying sizes). The abstract Objective class handles calling the logic, documenting it, and correctly aggregating the results across batches.
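To see why per-sample values can give a more precise average, consider the following sketch in plain Python. The numbers are made up, and the sketch assumes that scalar batch values are averaged with equal weight per batch; it does not reproduce DRYTorch's actual aggregation code.

```python
def mean(values: list[float]) -> float:
    """Arithmetic mean of a non-empty list."""
    return sum(values) / len(values)


# Per-sample errors for two mini-batches of different sizes (made-up numbers).
batches = [[1.0, 3.0, 5.0], [9.0]]

# One scalar per batch: each batch counts equally, whatever its size.
mean_of_batch_means = mean([mean(batch) for batch in batches])

# One value per sample: every sample counts equally.
total = sum(sum(batch) for batch in batches)
count = sum(len(batch) for batch in batches)
sample_weighted_mean = total / count

print(mean_of_batch_means)   # 6.0
print(sample_weighted_mean)  # 4.5
```

The single sample in the smaller last batch pulls the mean of batch means up to 6.0, while the sample-weighted mean over all four samples is 4.5.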

The Metric and MetricCollection classes

The Metric class defines a single metric. You can document it by giving it an explicit name and specifying whether higher or lower values are better. You can also combine different Metric instances with compatible signatures into a MetricCollection instance, or create one directly from a dictionary of named metric functions.

from torch.nn.functional import mse_loss as mse_loss_fn  # returns scalar value

from drytorch.lib.objectives import Metric


def mae_loss_fn(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Returns batched Mean Absolute Error (MAE) values."""
    return torch.abs(outputs - targets).flatten(1).mean(1)


mse_metric = Metric(mse_loss_fn, name='MSE', higher_is_better=False)
mae_metric = Metric(mae_loss_fn, 'MAE', higher_is_better=False)
metric_collection = mse_metric | mae_metric
metric_collection.update(tensor_a, tensor_b)
metric_collection.compute()
expected_metric_collection = {
    'MSE': torch.tensor(4.0),
    'MAE': torch.tensor(2.0),
}
if not metric_collection.compute() == expected_metric_collection:
    raise AssertionError('Metrics values should be as expected.')

Define a Custom Metric class

You can subclass the abstract AverageObjective class by overriding the calculate method. For different aggregation strategies, subclass Objective and override the _compute and _get_aggregator methods. In this example, we slightly reduce the calculation overhead to obtain the previous metrics.

from typing_extensions import override

from drytorch.lib.objectives import AverageObjective


class MyMetrics(AverageObjective[torch.Tensor, torch.Tensor]):
    """Class to calculate MSE and MAE more efficiently."""

    @override
    def calculate(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        diff = outputs - targets
        return {
            'MSE': torch.pow(diff, 2).flatten(1).mean(1),
            'MAE': torch.abs(diff).flatten(1).mean(1),
        }


my_metrics = MyMetrics()
my_metrics.update(tensor_a, tensor_b)
my_metrics.compute()
if not my_metrics.compute() == expected_metric_collection:
    raise AssertionError('Metrics values should be as before.')

LossBase, Loss and CompositionalLoss

LossBase is the abstract class for concrete loss classes, such as Loss and CompositionalLoss.

Loss is the counterpart of Metric: it accepts a single callable that serves both as the criterion for backpropagation and as a metric.

The CompositionalLoss class extends this idea by evaluating other metrics besides the main optimization criterion. This allows you to easily document and track the performance of the individual components that make up a more complex, composed loss function.

It is possible to create a compositional loss by using simple algebraic operations between a LossBase instance and an integer, float, or another LossBase instance. The resulting object’s formula attribute documents the specific operations and component losses utilized.
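The idea of building a loss and its formula through operator overloading can be sketched in a few lines of plain Python. The FormulaLoss class below is a hypothetical illustration, not the DRYTorch implementation: it works on floats rather than tensors, supports only three operations, and does not add parentheses for nested expressions.

```python
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

Criterion = Callable[[float, float], float]


@dataclass(frozen=True)
class FormulaLoss:
    """Hypothetical stand-in: a criterion plus a formula describing it."""

    criterion: Criterion
    formula: str

    def __add__(self, other: FormulaLoss) -> FormulaLoss:
        # Sum of two losses: add the values and join the formulas.
        return FormulaLoss(
            lambda o, t: self.criterion(o, t) + other.criterion(o, t),
            f'{self.formula} + {other.formula}',
        )

    def __rmul__(self, factor: float) -> FormulaLoss:
        # Called for `factor * loss` when the left operand is a plain number.
        return FormulaLoss(
            lambda o, t: factor * self.criterion(o, t),
            f'{factor} x {self.formula}',
        )

    def __pow__(self, exponent: int) -> FormulaLoss:
        return FormulaLoss(
            lambda o, t: self.criterion(o, t) ** exponent,
            f'{self.formula}^{exponent}',
        )


mse = FormulaLoss(lambda o, t: (o - t) ** 2, '[MSE]')
mae = FormulaLoss(lambda o, t: abs(o - t), '[MAE]')
composed = mse**2 + 0.5 * mae

print(composed.formula)              # [MSE]^2 + 0.5 x [MAE]
print(composed.criterion(1.0, 3.0))  # 17.0
```

Each operation returns a new object, so the component losses stay untouched and can still be evaluated and reported on their own.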

from torch.nn.functional import mse_loss as mse_loss_fn  # returns scalar value

from drytorch.lib.objectives import Loss


mse_loss = Loss(mse_loss_fn, name='MSE')
mae_loss = Loss(mae_loss_fn, 'MAE')
composed_loss = mse_loss**2 + 0.5 * mae_loss
composed_loss.update(tensor_a, tensor_b)
expected_metrics_from_loss = {
    'Combined Loss': torch.tensor(17.0),
    'MSE': torch.tensor(4.0),
    'MAE': torch.tensor(2.0),
}
if not composed_loss.compute() == expected_metrics_from_loss:
    raise AssertionError('Metrics values should be as expected.')
if composed_loss.formula != '[MSE]^2 + 0.5 x [MAE]':
    raise AssertionError('Formula mismatch.')

Distributed Data Parallelism

DRYTorch Objective classes are compatible with PyTorch’s DistributedDataParallel (DDP) module. Synchronization is handled by the library classes.

To use torchmetrics and torcheval metrics with DDP, we recommend using the from_torchmetrics and from_torcheval utility functions.

In particular, from_torchmetrics deactivates automatic synchronization and from_torcheval adds a sync method that calls torcheval.metrics.toolkit.sync_and_compute to synchronize the metrics across all processes.

The following code snippet shows the latter call, which emits a warning here because the current process is not running under DDP.

from drytorch.contrib.torcheval import from_torcheval


eval_metric_with_sync = from_torcheval(eval_metric)

eval_metric_with_sync.sync()
World size is 1, and metric(s) not synced. returning the input metric(s).