drytorch.lib.objectives

Module containing classes to create and combine loss and metrics.

The interface is similar to https://github.com/Lightning-AI/torchmetrics, with stricter typing and simpler construction. MetricCollection and CompositionalMetric from torchmetrics change their state; here a functional approach is preferred.

Functions

check_device(calculator, device)

Check that the metrics returned by the calculator are on the given device.

compute_metrics(calculator)

Compute and represent the metrics as a mapping of named values.

dict_apply(dict_fn, outputs, targets)

Apply the given tensor callables to the provided outputs and targets.
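The behaviour described for dict_apply can be pictured with the minimal sketch below. It is not the library's implementation: dict_apply_sketch, mse and mae are hypothetical names, and plain Python floats stand in for tensors so the example runs without PyTorch.

```python
from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import TypeVar

Output = TypeVar("Output")
Target = TypeVar("Target")


def dict_apply_sketch(
    dict_fn: Mapping[str, Callable[[Output, Target], float]],
    outputs: Output,
    targets: Target,
) -> dict[str, float]:
    """Apply every named callable to the same outputs and targets (hypothetical)."""
    return {name: fn(outputs, targets) for name, fn in dict_fn.items()}


def mse(outputs: list[float], targets: list[float]) -> float:
    # mean squared error over paired elements
    return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)


def mae(outputs: list[float], targets: list[float]) -> float:
    # mean absolute error over paired elements
    return sum(abs(o - t) for o, t in zip(outputs, targets)) / len(targets)


values = dict_apply_sketch({"mse": mse, "mae": mae}, [1.0, 2.0], [0.0, 2.0])
print(values)  # {'mse': 0.5, 'mae': 0.5}
```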

Classes

AverageObjective()

Class defining the default aggregation.

CompositionalLoss(criterion, *[, name, formula])

Loss resulting from an operation between two other losses.

JoinLossMetrics(loss, objective, /)

Loss resulting from adding an extra metric to a loss.

JoinMetrics(metric_a, metric_b, /)

Wrapper that joins two metrics.

Loss(fn, /, name[, higher_is_better])

Subclass for simple losses with a convenient constructor.

LossBase(criterion, name[, ...])

Collection of metrics, one of which serves as a loss.

Metric(fn, /, name[, higher_is_better])

Subclass for a single metric.

MetricCollection(**named_fn)

A collection of multiple metrics.

MetricTracker([metric_name, min_delta, ...])

Handle metric value tracking and improvement detection.

Objective()

Abstract base class for metrics or losses.

class AverageObjective

Bases: Objective[Output, Target]

Class defining the default aggregation.

Initialize.
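The documentation does not spell out the default aggregation. Assuming it is a running mean over the values produced at each update, it could look like the sketch below. AverageSketch is a hypothetical name, the mean is unweighted by batch size, and update() takes already calculated values instead of outputs and targets.

```python
class AverageSketch:
    """Hypothetical running-average aggregator, not drytorch's implementation."""

    def __init__(self) -> None:
        self.totals: dict[str, float] = {}
        self.counts: dict[str, int] = {}

    def update(self, values: dict[str, float]) -> dict[str, float]:
        # accumulate each named value and count how often it was seen
        for name, value in values.items():
            self.totals[name] = self.totals.get(name, 0.0) + value
            self.counts[name] = self.counts.get(name, 0) + 1
        return values

    def compute(self) -> dict[str, float]:
        # unweighted mean of every value passed to update()
        return {name: total / self.counts[name] for name, total in self.totals.items()}

    def reset(self) -> None:
        self.totals.clear()
        self.counts.clear()


agg = AverageSketch()
agg.update({"loss": 1.0})
agg.update({"loss": 3.0})
print(agg.compute())  # {'loss': 2.0}
```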

class CompositionalLoss(criterion: Callable[[dict[str, Tensor]], Tensor], *, name='Loss', higher_is_better: bool, formula: str = '', **named_fn: Callable[[Output, Target], Tensor])

Bases: LossBase[Output, Target]

Loss resulting from an operation between two other losses.

Initialize.

Parameters:
  • criterion (Callable[[dict[str, Tensor]], Tensor]) – function extracting a loss value from metric functions.

  • name (str) – identifier for the loss.

  • higher_is_better (bool) – True if higher values indicate better performance, False if lower values are better.

  • formula (str) – string representation of the loss formula.

  • named_fn (dict[str, Callable[[Output, Target], Tensor]]) – dictionary of named metric functions.
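To illustrate how criterion and named_fn are meant to interact, the hypothetical CompositionalLossSketch below evaluates every named function and then lets criterion combine the resulting values into a single loss. Floats replace tensors, and both the formula notation and the layout of the returned dictionary (loss first, then each metric) are assumptions.

```python
from __future__ import annotations

from collections.abc import Callable

Values = dict[str, float]


def mse(outputs: list[float], targets: list[float]) -> float:
    return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)


def mae(outputs: list[float], targets: list[float]) -> float:
    return sum(abs(o - t) for o, t in zip(outputs, targets)) / len(targets)


class CompositionalLossSketch:
    """Hypothetical stand-in mirroring the documented constructor arguments."""

    def __init__(
        self,
        criterion: Callable[[Values], float],
        *,
        name: str = "Loss",
        higher_is_better: bool,
        formula: str = "",
        **named_fn: Callable[[list[float], list[float]], float],
    ) -> None:
        self.criterion = criterion
        self.name = name
        self.higher_is_better = higher_is_better
        self.formula = formula
        self.named_fn = named_fn

    def calculate(self, outputs: list[float], targets: list[float]) -> Values:
        # evaluate every named metric, then derive the loss from those values
        values = {key: fn(outputs, targets) for key, fn in self.named_fn.items()}
        return {self.name: self.criterion(values), **values}


loss = CompositionalLossSketch(
    lambda v: v["MSE"] + 0.5 * v["MAE"],
    higher_is_better=False,
    formula="[MSE] + 0.5 * [MAE]",
    MSE=mse,
    MAE=mae,
)
print(loss.calculate([1.0, 2.0], [0.0, 2.0]))  # {'Loss': 0.75, 'MSE': 0.5, 'MAE': 0.5}
```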

calculate(outputs: Output, targets: Target) → dict[str, Tensor]

Calculates the loss and all associated metric values.

Parameters:
  • outputs (Output) – the model outputs.

  • targets (Target) – the ground truth targets.

Returns:

A dictionary containing the calculated loss and metric values.

Return type:

dict[str, Tensor]

class JoinLossMetrics(loss: LossBase[Output, Target], objective: Objective[Output, Target], /)

Bases: JoinMetrics, LossProtocol[Output, Target]

Loss resulting from adding an extra metric to a loss.

Prefer LossBase.watch() when the metric is a MetricCollection instance.

Initialize.

Parameters:
  • loss (LossBase[Output, Target]) – the primary loss.

  • objective (Objective[Output, Target]) – the extra metric to track alongside the loss.

forward(outputs: Output, targets: Target, /) → Tensor

Compute and return the loss.

Parameters:
  • outputs (Output) – the model outputs.

  • targets (Target) – the ground truth targets.

Returns:

The computed loss.

Return type:

Tensor

class JoinMetrics(metric_a: Objective[Output, Target], metric_b: Objective[Output, Target], /)

Bases: ObjectiveProtocol[Output, Target]

Wrapper that joins two metrics.

Prefer MetricCollection.__or__() when both objectives are MetricCollection instances.

Initialize.

Parameters:
  • metric_a (Objective[Output, Target]) – first objective.

  • metric_b (Objective[Output, Target]) – second objective.
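A minimal sketch of the joining behaviour, assuming each wrapped objective exposes compute() and reset(): the hypothetical JoinSketch forwards calls to both objectives and merges their results. How the library resolves duplicate metric names is not stated here; in this sketch, metric_b's value wins.

```python
from typing import Protocol


class ObjectiveLike(Protocol):
    def compute(self) -> dict[str, float]: ...

    def reset(self) -> None: ...


class FixedObjective:
    """Hypothetical objective reporting fixed values and counting resets."""

    def __init__(self, **values: float) -> None:
        self.values = values
        self.reset_calls = 0

    def compute(self) -> dict[str, float]:
        return dict(self.values)

    def reset(self) -> None:
        self.reset_calls += 1


class JoinSketch:
    """Hypothetical wrapper forwarding every call to two objectives."""

    def __init__(self, metric_a: ObjectiveLike, metric_b: ObjectiveLike, /) -> None:
        self.metric_a = metric_a
        self.metric_b = metric_b

    def compute(self) -> dict[str, float]:
        # merge both result dictionaries; on a name clash metric_b wins here
        return {**self.metric_a.compute(), **self.metric_b.compute()}

    def reset(self) -> None:
        self.metric_a.reset()
        self.metric_b.reset()


loss_part = FixedObjective(MSE=0.5)
extra_part = FixedObjective(Accuracy=0.9)
joined = JoinSketch(loss_part, extra_part)
print(joined.compute())  # {'MSE': 0.5, 'Accuracy': 0.9}
joined.reset()
```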

compute() → dict[str, Tensor]

Return the aggregated values of both metrics.

Returns:

A dictionary merging the computed values of both metrics.

Return type:

dict[str, Tensor]

reset() → None

Reset the internal state of both metrics.

Return type:

None

sync() → None

Synchronize metric states across processes.

Return type:

None

update(outputs: Output, targets: Target) → dict[str, Tensor]

Update both metrics with new outputs and targets.

Parameters:
  • outputs (Output) – the model outputs.

  • targets (Target) – the ground truth targets.

Returns:

A dictionary merging the calculated values of both metrics.

Return type:

dict[str, Tensor]

class Loss(fn: Callable[[Output, Target], Tensor], /, name: str, higher_is_better: bool = False)

Bases: CompositionalLoss[Output, Target]

Subclass for simple losses with a convenient constructor.

Initialize.

Parameters:
  • fn (Callable[[Output, Target], Tensor]) – the callable to calculate the loss.

  • name (str) – the name for the loss.

  • higher_is_better (bool) – True if higher values indicate better performance, False (the default) if lower values are better.

class LossBase(criterion: Callable[[dict[str, Tensor]], Tensor], name: str, higher_is_better: bool = False, formula: str = '', **named_fn: Callable[[Output, Target], Tensor])

Bases: MetricCollection[Output, Target], LossProtocol[Output, Target]

Collection of metrics, one of which serves as a loss.

name

identifier for the loss.

Type:

str

higher_is_better

True if higher values indicate better performance.

Type:

bool

formula

string representation of the loss formula.

Type:

str

criterion

logic extracting a loss value from the computed values.

Type:

collections.abc.Callable[[dict[str, torch.Tensor]], torch.Tensor]

Initialize.

Parameters:
  • criterion (Callable[[dict[str, Tensor]], Tensor]) – logic extracting a loss value from the computed values.

  • name (str) – identifier for the loss.

  • higher_is_better (bool) – True if higher values indicate better performance, False if lower values are better.

  • formula (str) – string representation of the loss formula.

  • **named_fn (dict[str, Callable[[Output, Target], Tensor]]) – dictionary of named functions to calculate.

forward(outputs: Output, targets: Target) → Tensor

Performs a forward pass, updates metrics, and computes the loss.

Parameters:
  • outputs (Output) – the model outputs.

  • targets (Target) – the ground truth targets.

Returns:

The computed loss value.

Return type:

Tensor

watch(metric: MetricCollection[Output, Target]) → None

Include the metrics of another MetricCollection in its own metrics.

Parameters:

metric (MetricCollection[Output, Target]) – the MetricCollection to watch.

Return type:

None

__neg__() → CompositionalLoss[Output, Target]

Constructor from an existing template.

Returns:

A new CompositionalLoss representing the negated loss.

Return type:

CompositionalLoss[Output, Target]

__add__(other: LossBase[Output, Target] | float) → CompositionalLoss[Any, Any]

Constructor from existing templates.

Parameters:

other (LossBase[Output, Target] | float) – the other loss or float to add.

Returns:

A new CompositionalLoss representing the sum.

Return type:

CompositionalLoss[Any, Any]

__radd__(other: float) → CompositionalLoss[Any, Any]

Constructor from existing templates.

Parameters:

other (float) – the float to add to the loss.

Returns:

A new CompositionalLoss representing the sum.

Return type:

CompositionalLoss[Any, Any]

__sub__(other: LossBase[Output, Target] | float) → CompositionalLoss[Output, Target]

Constructor from existing templates.

Parameters:

other (LossBase[Output, Target] | float) – the other loss or float to subtract.

Returns:

A new CompositionalLoss representing the difference.

Return type:

CompositionalLoss[Output, Target]

__rsub__(other: float) → CompositionalLoss[Output, Target]

Constructor from existing templates.

Parameters:

other (float) – the float from which to subtract the loss.

Returns:

A new CompositionalLoss representing the difference.

Return type:

CompositionalLoss[Output, Target]

__mul__(other: LossBase[Output, Target] | float) → CompositionalLoss[Output, Target]

Constructor from existing templates.

Parameters:

other (LossBase[Output, Target] | float) – the other loss or float to multiply by.

Returns:

A new CompositionalLoss representing the product.

Return type:

CompositionalLoss[Output, Target]

__rmul__(other: float) → CompositionalLoss[Output, Target]

Constructor from existing templates.

Parameters:

other (float) – the float to multiply the loss by.

Returns:

A new CompositionalLoss representing the product.

Return type:

CompositionalLoss[Output, Target]

__truediv__(other: LossBase[Output, Target] | float) → CompositionalLoss[Output, Target]

Constructor from existing templates.

Parameters:

other (LossBase[Output, Target] | float) – the other loss or float to divide by.

Returns:

A new CompositionalLoss representing the quotient.

Return type:

CompositionalLoss[Output, Target]

__rtruediv__(other: float) → CompositionalLoss[Output, Target]

Constructor from existing templates.

Parameters:

other (float) – the float to be divided by the loss.

Returns:

A new CompositionalLoss representing the quotient.

Return type:

CompositionalLoss[Output, Target]

__pow__(other: float) → CompositionalLoss[Output, Target]

Constructor from existing templates.

Parameters:

other (float) – the power to raise the loss to.

Returns:

A new CompositionalLoss representing the result.

Return type:

CompositionalLoss[Output, Target]
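Each operator above returns a new CompositionalLoss rather than modifying its operands. The hypothetical LossSketch below illustrates one way such composition can work: the new criterion combines the operands' criteria, and a formula string records the operation. It also shows why the reflected variants (__radd__, __rsub__, __rmul__, __rtruediv__) exist: an expression such as 1 - loss starts from a float, so Python falls back to loss.__rsub__(1). This is not the library's code; the formula notation, the handling of constants, and the use of floats instead of tensors are assumptions.

```python
from __future__ import annotations

import operator
from collections.abc import Callable

Values = dict[str, float]
Fn = Callable[[list[float], list[float]], float]


class LossSketch:
    """Hypothetical loss: a criterion over named metric values plus a formula."""

    def __init__(self, criterion: Callable[[Values], float], formula: str, **named_fn: Fn) -> None:
        self.criterion = criterion
        self.formula = formula
        self.named_fn = named_fn

    @classmethod
    def single(cls, name: str, fn: Fn) -> LossSketch:
        # a simple loss just reads its own entry from the computed values
        return cls(lambda values: values[name], f"[{name}]", **{name: fn})

    def __call__(self, outputs: list[float], targets: list[float]) -> float:
        values = {key: fn(outputs, targets) for key, fn in self.named_fn.items()}
        return self.criterion(values)

    def _combine(self, other: LossSketch | float, op: Callable[[float, float], float],
                 symbol: str, reflected: bool = False) -> LossSketch:
        # every operator returns a NEW loss; neither operand is modified
        if isinstance(other, LossSketch):
            other_criterion, other_formula = other.criterion, other.formula
            named_fn = {**self.named_fn, **other.named_fn}
        else:
            constant = float(other)
            other_criterion, other_formula = (lambda values: constant), str(constant)
            named_fn = dict(self.named_fn)
        left, right = self.criterion, other_criterion
        left_formula, right_formula = self.formula, other_formula
        if reflected:  # e.g. 1 - loss is dispatched to loss.__rsub__(1)
            left, right = right, left
            left_formula, right_formula = right_formula, left_formula
        return LossSketch(lambda values: op(left(values), right(values)),
                          f"({left_formula} {symbol} {right_formula})", **named_fn)

    def __neg__(self) -> LossSketch:
        criterion = self.criterion
        return LossSketch(lambda values: -criterion(values), f"-{self.formula}", **self.named_fn)

    def __add__(self, other: LossSketch | float) -> LossSketch:
        return self._combine(other, operator.add, "+")

    def __radd__(self, other: float) -> LossSketch:
        return self._combine(other, operator.add, "+", reflected=True)

    def __sub__(self, other: LossSketch | float) -> LossSketch:
        return self._combine(other, operator.sub, "-")

    def __rsub__(self, other: float) -> LossSketch:
        return self._combine(other, operator.sub, "-", reflected=True)

    def __mul__(self, other: LossSketch | float) -> LossSketch:
        return self._combine(other, operator.mul, "*")

    def __rmul__(self, other: float) -> LossSketch:
        return self._combine(other, operator.mul, "*", reflected=True)

    def __truediv__(self, other: LossSketch | float) -> LossSketch:
        return self._combine(other, operator.truediv, "/")

    def __rtruediv__(self, other: float) -> LossSketch:
        return self._combine(other, operator.truediv, "/", reflected=True)

    def __pow__(self, other: float) -> LossSketch:
        return self._combine(other, operator.pow, "**")


def mse(outputs: list[float], targets: list[float]) -> float:
    return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)


def mae(outputs: list[float], targets: list[float]) -> float:
    return sum(abs(o - t) for o, t in zip(outputs, targets)) / len(targets)


mse_loss = LossSketch.single("MSE", mse)
mae_loss = LossSketch.single("MAE", mae)
combined = 2 * mse_loss + mae_loss / 4
print(combined.formula)  # ((2.0 * [MSE]) + ([MAE] / 4.0))
print(combined([1.0, 2.0], [0.0, 2.0]))  # 1.125
```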

__repr__()

Returns the string representation of the LossBase object.

class Metric(fn: Callable[[Output, Target], Tensor], /, name: str, higher_is_better: bool | None = None)

Bases: MetricCollection[Output, Target]

Subclass for a single metric.

fun

the callable that computes the metric value.

Type:

collections.abc.Callable[[drytorch.lib.objectives.Output, drytorch.lib.objectives.Target], torch.Tensor]

name

identifier for the metric.

Type:

str

higher_is_better

True if higher values indicate better performance.

Type:

bool | None

Initialize.

Parameters:
  • fn (Callable[[Output, Target], Tensor]) – the callable that computes the metric value.

  • name (str) – identifier for the metric.

  • higher_is_better (bool | None) – True if higher values indicate better performance, False if lower values are better, None if unspecified.

class MetricCollection(**named_fn: Callable[[Output, Target], Tensor])

Bases: AverageObjective[Output, Target]

A collection of multiple metrics.

named_fn

dictionary of named functions to calculate.

Type:

dict[str, collections.abc.Callable[[drytorch.lib.objectives.Output, drytorch.lib.objectives.Target], torch.Tensor]]

Initialize.

Parameters:

**named_fn (dict[str, Callable[[Output, Target], Tensor]]) – dictionary of named functions to calculate.

calculate(outputs: Output, targets: Target) → dict[str, Tensor]

Calculates the values for all metrics in the collection.

Parameters:
  • outputs (Output) – the model outputs.

  • targets (Target) – the ground truth targets.

Returns:

A dictionary of calculated metric values.

Return type:

dict[str, Tensor]

__or__(other: MetricCollection[Output, Target]) → MetricCollection[Output, Target]

Constructor using existing MetricCollection objects as templates.

This method does not merge the internal states of the two collections. To combine them, call merge_state() separately.

Parameters:

other (MetricCollection[Output, Target]) – another MetricCollection object to combine with.

Returns:

A new instance containing metrics from both instances.

Return type:

MetricCollection[Output, Target]
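A minimal sketch of the template semantics described for __or__: the result receives the named functions of both collections but starts with fresh state, so anything accumulated in the originals is not carried over. CollectionSketch and its history attribute are hypothetical; the real class's state layout and its behaviour on duplicate names are not documented here, and in this sketch the right-hand collection wins on a clash.

```python
from collections.abc import Callable
from typing import Any

Fn = Callable[[Any, Any], float]


class CollectionSketch:
    """Hypothetical collection of named metric functions with accumulated state."""

    def __init__(self, **named_fn: Fn) -> None:
        self.named_fn = named_fn
        self.history: list[dict[str, float]] = []  # stand-in for the internal state

    def update(self, outputs: Any, targets: Any) -> dict[str, float]:
        values = {name: fn(outputs, targets) for name, fn in self.named_fn.items()}
        self.history.append(values)
        return values

    def __or__(self, other: "CollectionSketch") -> "CollectionSketch":
        # use both instances only as templates: functions are combined,
        # accumulated state is not (later names win in this sketch)
        merged = {**self.named_fn, **other.named_fn}
        return CollectionSketch(**merged)


first = CollectionSketch(AbsErr=lambda o, t: abs(o - t))
second = CollectionSketch(SqErr=lambda o, t: (o - t) ** 2)
first.update(3.0, 1.0)

both = first | second
print(sorted(both.named_fn))  # ['AbsErr', 'SqErr']
print(both.history)  # []
print(both.update(3.0, 1.0))  # {'AbsErr': 2.0, 'SqErr': 4.0}
```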

class MetricTracker(metric_name: str | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1))

Bases: Generic[Output, Target]

Handle metric value tracking and improvement detection.

This class is responsible for storing metric history, determining improvements, and managing patience countdown.

Note: this class can be used to automatically modify the training strategy. Therefore, it does not follow the library conventions for a tracker.

metric_name

the name of the metric to monitor.

Type:

str | None

min_delta

the minimum change required to qualify as an improvement.

Type:

float

patience

number of checks to wait before triggering a callback.

Type:

int

best_is

whether higher or lower values are better.

Type:

Literal['auto', 'higher', 'lower']

filter_fn

function to aggregate recent metric values.

Type:

collections.abc.Callable[[collections.abc.Sequence[float]], float]

history

logs of the recorded metrics.

Type:

list[float]

Initialize.

Parameters:
  • metric_name (str | None) – name of the metric to track.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • patience (int) – number of checks to wait before triggering a callback.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values.

property best_value: float

Get the best result observed so far.

Returns:

the best filtered value (according to the ‘best_is’ criterion).

Raises:

ResultNotAvailableError – if no results have been logged yet.

property filtered_value: float

Get the current value.

Returns:

the current value aggregated from recent ones.

Raises:

ResultNotAvailableError – if no results have been logged yet.
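filter_fn reduces the recorded history to the value that is actually compared. The hypothetical FilteredHistory below contrasts the default, operator.itemgetter(-1), which simply picks the latest value, with a moving average over the last three values. ResultNotAvailableSketch stands in for ResultNotAvailableError, and applying the filter to the full history is an assumption.

```python
from __future__ import annotations

import operator
from collections.abc import Callable, Sequence
from statistics import fmean


class ResultNotAvailableSketch(Exception):
    """Hypothetical stand-in for ResultNotAvailableError."""


class FilteredHistory:
    """Hypothetical history holder exposing a filtered value."""

    def __init__(
        self,
        filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1),
    ) -> None:
        self.filter_fn = filter_fn
        self.history: list[float] = []

    def add_value(self, value: float) -> None:
        self.history.append(value)

    @property
    def filtered_value(self) -> float:
        if not self.history:
            raise ResultNotAvailableSketch("no values have been logged yet")
        return self.filter_fn(self.history)


def last_three_mean(values: Sequence[float]) -> float:
    # smooth noisy curves by averaging the three most recent values
    return fmean(values[-3:])


latest = FilteredHistory()
smoothed = FilteredHistory(last_three_mean)
for value in [0.5, 0.25, 0.75, 0.125]:
    latest.add_value(value)
    smoothed.add_value(value)

print(latest.filtered_value)  # 0.125
print(smoothed.filtered_value)  # 0.375
```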

add_value(value: float) → None

Add a new metric value to the history.

Parameters:

value (float) – the metric value to add.

Return type:

None

is_better(value: float, reference: float) → bool

Determine if the value is better than a reference value.

When best_is is in ‘auto’ mode, it is assumed that the given value is better than the first recorded one.

Parameters:
  • value (float) – the value to compare.

  • reference (float) – the reference.

Returns:

True if value is a potential improvement, False otherwise.

Return type:

bool

is_improving() → bool

Determine if the model performance is improving.

Returns:

True if there has been an improvement, False otherwise.

Return type:

bool

Side Effects:

If there is no improvement, the patience countdown is reduced. Otherwise, it is restored to the maximum.
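The interplay of is_better, min_delta and the patience countdown can be sketched as follows. TrackerSketch is hypothetical: it models only the 'higher' and 'lower' modes, compares the latest value rather than a filtered one, treats the first value as an improvement, and assumes the callback should fire once the countdown drops below zero. The library's exact rules may differ.

```python
from __future__ import annotations

from typing import Literal


class TrackerSketch:
    """Hypothetical tracker covering only the 'higher' and 'lower' modes."""

    def __init__(
        self,
        min_delta: float = 1e-8,
        patience: int = 0,
        best_is: Literal["higher", "lower"] = "lower",
    ) -> None:
        self.min_delta = min_delta
        self.patience = patience
        self.best_is = best_is
        self.history: list[float] = []
        self.best: float | None = None
        self._countdown = patience

    def add_value(self, value: float) -> None:
        self.history.append(value)

    def is_better(self, value: float, reference: float) -> bool:
        # an improvement must beat the reference by more than min_delta
        if self.best_is == "higher":
            return value > reference + self.min_delta
        return value < reference - self.min_delta

    def is_improving(self) -> bool:
        current = self.history[-1]  # a fuller version would apply a filter_fn here
        if self.best is None or self.is_better(current, self.best):
            self.best = current
            self._countdown = self.patience  # improvement: restore full patience
            return True
        self._countdown -= 1  # no improvement: spend one unit of patience
        return False

    def is_patient(self) -> bool:
        # in this sketch the callback should fire once the countdown is negative
        return self._countdown >= 0


tracker = TrackerSketch(patience=1)
results = []
for loss in [1.0, 0.8, 0.8, 0.81]:
    tracker.add_value(loss)
    results.append((tracker.is_improving(), tracker.is_patient()))
print(results)  # [(True, True), (True, True), (False, True), (False, False)]
```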

is_patient() → bool

Check whether there is patience left, i.e. whether to keep waiting before triggering the callback.

Return type:

bool

reset_patience() → None

Reset patience countdown to the maximum.

Return type:

None

class Objective

Bases: ObjectiveProtocol[Output, Target]

Abstract base class for metrics or losses.

Initialize.

compute() → dict[str, Tensor]

Return the aggregated objective value(s).

Despite its name, which follows common practice, this method caches previously computed values and returns them when available.

Returns:

A dictionary of computed metric values.

Return type:

dict[str, Tensor]
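One way to realize the caching described for compute() is shown below. The hypothetical CachedAverage stores its last result and discards it whenever update() or reset() changes the state; when exactly the library invalidates its cache is an assumption here.

```python
from __future__ import annotations


class CachedAverage:
    """Hypothetical objective whose compute() reuses its last result."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0
        self.aggregations = 0  # how many times the aggregation actually ran
        self._cache: dict[str, float] | None = None

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1
        self._cache = None  # new data invalidates the cached result

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0
        self._cache = None

    def compute(self) -> dict[str, float]:
        if self._cache is None:  # only aggregate when nothing is cached
            self.aggregations += 1
            self._cache = {"mean": self.total / self.count}
        return self._cache


metric = CachedAverage()
metric.update(1.0)
metric.update(2.0)
print(metric.compute())  # {'mean': 1.5}
print(metric.compute())  # {'mean': 1.5}, served from the cache
print(metric.aggregations)  # 1
metric.update(3.0)
print(metric.compute())  # {'mean': 2.0}
print(metric.aggregations)  # 2
```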

update(outputs: Output, targets: Target) → dict[str, Tensor]

Updates the objective’s internal state with new outputs and targets.

Parameters:
  • outputs (Output) – the model outputs.

  • targets (Target) – the ground truth targets.

Returns:

A dictionary of the calculated metric values for the current update.

Return type:

dict[str, Tensor]

reset() → None

Resets the internal state of the instance.

Return type:

None

abstractmethod calculate(outputs: Output, targets: Target) → dict[str, Tensor]

Method responsible for the calculations.

Parameters:
  • outputs (Output) – model outputs.

  • targets (Target) – ground truth.

Returns:

A dictionary of calculated metric values.

Return type:

dict[str, Tensor]

copy() → Self

Create a deep copy of self.

Return type:

Self

merge_state(other: Self) → None

Merge metric states.

Parameters:
  • other (Self) – metric to be merged with.

Return type:

None
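Why merging states differs from averaging finished results can be seen with a running mean. The hypothetical MeanState keeps a sum and a count; merging adds both, so the merged mean covers every update, whereas averaging the two separate means would weight each part equally regardless of its size. Representing the state as a sum and a count is an assumption made for illustration.

```python
from __future__ import annotations


class MeanState:
    """Hypothetical averaged state kept as a running sum and a count."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def merge_state(self, other: MeanState) -> None:
        # sums and counts add up, so the merged mean covers every update
        self.total += other.total
        self.count += other.count

    def compute(self) -> float:
        return self.total / self.count


part_a = MeanState()
for value in (1.0, 2.0):
    part_a.update(value)
part_b = MeanState()
part_b.update(6.0)

naive = (part_a.compute() + part_b.compute()) / 2  # 3.75: ignores the counts
part_a.merge_state(part_b)
print(part_a.compute())  # 3.0
```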

sync() → None

Synchronize metric states across processes.

Return type:

None

__deepcopy__(memo: dict[int, Any]) → Self

Deep copy magic method.

Parameters:

memo (dict[int, Any]) – dictionary of already copied objects.

Returns:

A deep copy of the object.

Return type:

Self

compute_metrics(calculator: ObjectiveProtocol[Any, Any]) → Mapping[str, float]

Compute and represent the metrics as a mapping of named values.

Parameters:

calculator (ObjectiveProtocol[Any, Any]) – An ObjectiveProtocol instance from which to compute metrics.

Returns:

A mapping of metric names to their float values.

Return type:

Mapping[str, float]
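compute_metrics turns tensor results into plain numbers, which is convenient for logging. The sketch below assumes it calls compute() once and converts each zero-dimensional tensor with item(); compute_metrics_sketch, FakeScalar and FakeCalculator are hypothetical stand-ins so the example runs without PyTorch.

```python
from __future__ import annotations

from collections.abc import Mapping


class FakeScalar:
    """Stand-in for a zero-dimensional tensor; only item() is modelled."""

    def __init__(self, value: float) -> None:
        self._value = value

    def item(self) -> float:
        return self._value


class FakeCalculator:
    """Stand-in for an ObjectiveProtocol instance."""

    def compute(self) -> dict[str, FakeScalar]:
        return {"MSE": FakeScalar(0.5), "MAE": FakeScalar(0.25)}


def compute_metrics_sketch(calculator: FakeCalculator) -> Mapping[str, float]:
    # call compute() once and convert every value to a Python float
    return {name: float(value.item()) for name, value in calculator.compute().items()}


print(compute_metrics_sketch(FakeCalculator()))  # {'MSE': 0.5, 'MAE': 0.25}
```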