drytorch.lib.objectives
Module containing classes to create and combine losses and metrics.
The interface is similar to that of torchmetrics (https://github.com/Lightning-AI/torchmetrics), with stricter typing and simpler construction. Unlike MetricCollection and CompositionalMetric in torchmetrics, which change their own state, the classes here prefer a functional approach: combining objects returns a new instance.
Functions
- Check that the metrics returned by the calculator are on the given device.
- Compute and represent the metrics as a mapping of named values.
- Apply the given tensor callables to the provided outputs and targets.
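The last helper can be pictured as a dictionary comprehension over named callables. The following is a minimal sketch, assuming the callables are stored in a name-to-function mapping; `apply_named` is a hypothetical name rather than the module's function, and plain floats stand in for `torch.Tensor` to keep the example dependency-free.

```python
from collections.abc import Callable, Mapping
from typing import TypeVar

Output = TypeVar("Output")
Target = TypeVar("Target")


def apply_named(
    named_fn: Mapping[str, Callable[[Output, Target], float]],
    outputs: Output,
    targets: Target,
) -> dict[str, float]:
    """Hypothetical helper: apply every named callable to the same outputs and targets."""
    return {name: fn(outputs, targets) for name, fn in named_fn.items()}


def mae(outputs: list[float], targets: list[float]) -> float:
    # Mean absolute error over paired values.
    return sum(abs(o - t) for o, t in zip(outputs, targets)) / len(outputs)


def mse(outputs: list[float], targets: list[float]) -> float:
    # Mean squared error over paired values.
    return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(outputs)


values = apply_named({"MAE": mae, "MSE": mse}, [1.0, 2.0], [0.0, 4.0])
print(values)  # {'MAE': 1.5, 'MSE': 2.5}
```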
Classes
- AverageObjective – Class defining the default aggregation.
- CompositionalLoss – Loss resulting from an operation between two other losses.
- JoinLossMetrics – Loss resulting from adding an extra metric to a loss.
- JoinMetrics – Wrapper that joins two metrics.
- Loss – Subclass for simple losses with a convenient constructor.
- LossBase – Collection of metrics, one of which serves as a loss.
- Metric – Subclass for a single metric.
- MetricCollection – A collection of multiple metrics.
- MetricTracker – Handle metric value tracking and improvement detection.
- Objective – Abstract base class for metrics or losses.
- class AverageObjective
Bases: Objective[Output, Target]
Class defining the default aggregation.
Initialize.
- class CompositionalLoss(criterion: Callable[[dict[str, Tensor]], Tensor], *, name='Loss', higher_is_better: bool, formula: str = '', **named_fn: Callable[[Output, Target], Tensor])
Bases: LossBase[Output, Target]
Loss resulting from an operation between two other losses.
Initialize.
- Parameters:
criterion (Callable[[dict[str, Tensor]], Tensor]) – function extracting a loss value from metric functions.
name (str) – identifier for the loss.
higher_is_better (bool) – True if higher values indicate better performance, False if lower values are better.
formula (str) – string representation of the loss formula.
**named_fn (dict[str, Callable[[Output, Target], Tensor]]) – dictionary of named metric functions.
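The `criterion` receives the dictionary of values computed by the named functions and reduces it to a single loss value. A minimal sketch of such a criterion follows, assuming the dictionary keys are the names given as `**named_fn`; the metric names, the 0.5 weight, and the `formula` text are illustrative, and floats replace tensors.

```python
# Values as the named functions might return them for one batch (illustrative numbers).
computed: dict[str, float] = {"CrossEntropy": 0.75, "L2": 2.0}


def criterion(values: dict[str, float]) -> float:
    # Reduce the whole dictionary to one scalar: cross-entropy plus a weighted penalty.
    return values["CrossEntropy"] + 0.5 * values["L2"]


# Free-form description in the spirit of the `formula` parameter (exact format assumed).
formula = "CrossEntropy + 0.5 * L2"

loss_value = criterion(computed)  # 0.75 + 0.5 * 2.0 = 1.75
```

Keeping the criterion separate from the named functions means every metric is still reported individually, while only the criterion's output is used as the loss.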
- class JoinLossMetrics(loss: LossBase[Output, Target], objective: Objective[Output, Target], /)
Bases: JoinMetrics, LossProtocol[Output, Target]
Loss resulting from adding an extra metric to a loss.
When the metric is a MetricCollection, prefer LossBase.watch().
Initialize.
- Parameters:
loss (LossBase[Output, Target]) – the loss to extend.
objective (Objective[Output, Target]) – the metric to add to the loss.
- class JoinMetrics(metric_a: Objective[Output, Target], metric_b: Objective[Output, Target], /)
Bases: ObjectiveProtocol[Output, Target]
Wrapper that joins two metrics.
When both metrics are MetricCollection instances, prefer MetricCollection.__or__().
Initialize.
- Parameters:
metric_a (Objective[Output, Target]) – the first metric.
metric_b (Objective[Output, Target]) – the second metric.
- sync() -> None
Synchronize metric states across processes.
- Parameters:
self (Self)
- Return type:
None
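One plausible reading of "wrapper" is that the joined object forwards each call to the two wrapped objectives and merges their results, so each wrapped objective keeps its own state. The sketch below illustrates that reading only; `Running` and `Joined` are hypothetical classes with a scalar running total, not the library's `Objective` or `JoinMetrics`.

```python
from collections.abc import Callable


class Running:
    """Hypothetical minimal objective: one named function and a running total."""

    def __init__(self, name: str, fn: Callable[[float, float], float]) -> None:
        self.name = name
        self.fn = fn
        self.total = 0.0

    def update(self, output: float, target: float) -> dict[str, float]:
        value = self.fn(output, target)
        self.total += value
        return {self.name: value}

    def reset(self) -> None:
        self.total = 0.0


class Joined:
    """Hypothetical join: forwards each call to both wrapped objectives."""

    def __init__(self, metric_a: Running, metric_b: Running) -> None:
        self.metric_a = metric_a
        self.metric_b = metric_b

    def update(self, output: float, target: float) -> dict[str, float]:
        # Both wrapped objects update their own state; the per-batch results are merged.
        return self.metric_a.update(output, target) | self.metric_b.update(output, target)

    def reset(self) -> None:
        self.metric_a.reset()
        self.metric_b.reset()


error = Running("Error", lambda o, t: o - t)
squared = Running("Squared", lambda o, t: (o - t) ** 2)
joined = Joined(error, squared)
result = joined.update(3.0, 1.0)  # {'Error': 2.0, 'Squared': 4.0}
```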
- class Loss(fn: Callable[[Output, Target], Tensor], /, name: str, higher_is_better: bool = False)
Bases: CompositionalLoss[Output, Target]
Subclass for simple losses with a convenient constructor.
Initialize.
- class LossBase(criterion: Callable[[dict[str, Tensor]], Tensor], name: str, higher_is_better: bool = False, formula: str = '', **named_fn: Callable[[Output, Target], Tensor])
Bases: MetricCollection[Output, Target], LossProtocol[Output, Target]
Collection of metrics, one of which serves as a loss.
- criterion
logic extracting a loss value from the computed values.
- Type:
collections.abc.Callable[[dict[str, torch.Tensor]], torch.Tensor]
Initialize.
- Parameters:
criterion (Callable[[dict[str, Tensor]], Tensor]) – logic extracting a loss value from the computed values.
name (str) – identifier for the loss.
higher_is_better (bool) – True if higher values indicate better performance, False if lower values are better.
formula (str) – string representation of the loss formula.
**named_fn (dict[str, Callable[[Output, Target], Tensor]]) – dictionary of named functions to calculate.
- forward(outputs: Output, targets: Target) -> Tensor
Performs a forward pass, updates metrics, and computes the loss.
- Parameters:
outputs (Output) – the model outputs.
targets (Target) – the ground truth targets.
- Returns:
The computed loss value.
- Return type:
Tensor
- watch(metric: MetricCollection[Output, Target]) -> None
Add the metrics of another MetricCollection to the ones it computes.
- Parameters:
metric (MetricCollection[Output, Target]) – the other Objective to watch.
- Return type:
None
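The behavior described for `forward` and `watch` can be sketched without torch: compute every named function, keep the values, and return the criterion applied to them; watching adds more functions to compute. `TinyLoss` is a hypothetical, simplified stand-in, not `LossBase`: it stores only the last values instead of an aggregated state, uses floats in place of tensors, and its `watch` takes keyword functions rather than a MetricCollection.

```python
from collections.abc import Callable

Fn = Callable[[list[float], list[float]], float]


class TinyLoss:
    """Hypothetical stand-in for LossBase (no aggregation, floats only)."""

    def __init__(self, criterion: Callable[[dict[str, float]], float], **named_fn: Fn) -> None:
        self.criterion = criterion
        self.named_fn: dict[str, Fn] = dict(named_fn)
        self.last: dict[str, float] = {}

    def forward(self, outputs: list[float], targets: list[float]) -> float:
        # Compute every metric, keep the values, then reduce them to the loss.
        self.last = {name: fn(outputs, targets) for name, fn in self.named_fn.items()}
        return self.criterion(self.last)

    def watch(self, **other_fn: Fn) -> None:
        # Watched functions are computed alongside the loss; the criterion ignores them.
        self.named_fn.update(other_fn)


def mse(outputs: list[float], targets: list[float]) -> float:
    return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(outputs)


def max_error(outputs: list[float], targets: list[float]) -> float:
    return max(abs(o - t) for o, t in zip(outputs, targets))


loss = TinyLoss(lambda values: values["MSE"], MSE=mse)
loss.watch(MaxError=max_error)
value = loss.forward([1.0, 2.0], [0.0, 4.0])  # MSE = (1 + 4) / 2 = 2.5
```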
- __neg__() -> CompositionalLoss[Output, Target]
Constructor from an existing template: return a new loss with the opposite sign.
- Returns:
A new CompositionalLoss representing the negated loss.
- Return type:
CompositionalLoss[Output, Target]
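- __add__(other: LossBase[Output, Target] | float) -> CompositionalLoss[Any, Any]
Constructor from existing templates.
- Parameters:
other (LossBase[Output, Target] | float) – the other loss or float to add.
- Returns:
A new CompositionalLoss representing the sum.
- Return type:
CompositionalLoss[Any, Any]
- __radd__(other: float) -> CompositionalLoss[Any, Any]
Constructor from existing templates.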
- Parameters:
other (float) – the float to add to the loss.
- Returns:
A new CompositionalLoss representing the sum.
- Return type:
CompositionalLoss[Any, Any]
- __sub__(other: LossBase[Output, Target] | float) -> CompositionalLoss[Output, Target]
Constructor from existing templates.
- Parameters:
other (LossBase[Output, Target] | float) – the other loss or float to subtract.
- Returns:
A new CompositionalLoss representing the difference.
- Return type:
CompositionalLoss[Output, Target]
- __rsub__(other: float) -> CompositionalLoss[Output, Target]
Constructor from existing templates.
- Parameters:
other (float) – the float from which to subtract the loss.
- Returns:
A new CompositionalLoss representing the difference.
- Return type:
CompositionalLoss[Output, Target]
- __mul__(other: LossBase[Output, Target] | float) -> CompositionalLoss[Output, Target]
Constructor from existing templates.
- Parameters:
other (LossBase[Output, Target] | float) – the other loss or float to multiply by.
- Returns:
A new CompositionalLoss representing the product.
- Return type:
CompositionalLoss[Output, Target]
- __rmul__(other: float) -> CompositionalLoss[Output, Target]
Constructor from existing templates.
- Parameters:
other (float) – the float to multiply the loss by.
- Returns:
A new CompositionalLoss representing the product.
- Return type:
CompositionalLoss[Output, Target]
- __truediv__(other: LossBase[Output, Target] | float) -> CompositionalLoss[Output, Target]
Constructor from existing templates.
- Parameters:
other (LossBase[Output, Target] | float) – the other loss or float to divide by.
- Returns:
A new CompositionalLoss representing the quotient.
- Return type:
CompositionalLoss[Output, Target]
- __rtruediv__(other: float) -> CompositionalLoss[Output, Target]
Constructor from existing templates.
- Parameters:
other (float) – the float to be divided by the loss.
- Returns:
A new CompositionalLoss representing the quotient.
- Return type:
CompositionalLoss[Output, Target]
- __pow__(other: float) -> CompositionalLoss[Output, Target]
Constructor from existing templates.
- Parameters:
other (float) – the power to raise the loss to.
- Returns:
A new CompositionalLoss representing the result.
- Return type:
CompositionalLoss[Output, Target]
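The operators above build new loss objects from existing ones instead of modifying them. The sketch below shows how such arithmetic can be composed from criteria and formula strings; `Composable` is a hypothetical class, not `CompositionalLoss`, it covers only `+`, `-`, `*` and unary `-`, and the formula format is an assumption. Note that reflected operators for non-commutative operations (`__rsub__`, `__rtruediv__`) must put the float on the left.

```python
from __future__ import annotations

import operator
from collections.abc import Callable

Values = dict[str, float]


class Composable:
    """Hypothetical sketch of loss arithmetic; not the library's CompositionalLoss."""

    def __init__(self, criterion: Callable[[Values], float], formula: str) -> None:
        self.criterion = criterion
        self.formula = formula

    def _combine(self, other: Composable | float, op: Callable[[float, float], float], symbol: str) -> Composable:
        # Build a new object; neither operand is modified.
        if isinstance(other, Composable):
            right, right_formula = other.criterion, other.formula
        else:
            constant = float(other)
            right, right_formula = (lambda _values: constant), str(constant)
        left = self.criterion
        return Composable(lambda values: op(left(values), right(values)), f"({self.formula} {symbol} {right_formula})")

    def __add__(self, other: Composable | float) -> Composable:
        return self._combine(other, operator.add, "+")

    def __radd__(self, other: float) -> Composable:
        return self + other  # addition is commutative

    def __sub__(self, other: Composable | float) -> Composable:
        return self._combine(other, operator.sub, "-")

    def __rsub__(self, other: float) -> Composable:
        # Non-commutative: compute other - self, so the float goes on the left.
        constant = float(other)
        return Composable(lambda values: constant - self.criterion(values), f"({constant} - {self.formula})")

    def __mul__(self, other: Composable | float) -> Composable:
        return self._combine(other, operator.mul, "*")

    def __rmul__(self, other: float) -> Composable:
        return self * other  # multiplication is commutative

    def __neg__(self) -> Composable:
        return Composable(lambda values: -self.criterion(values), f"-{self.formula}")


ce = Composable(lambda v: v["CE"], "CE")
l2 = Composable(lambda v: v["L2"], "L2")
total = ce + 0.5 * l2  # formula: (CE + (L2 * 0.5))
```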
- class Metric(fn: Callable[[Output, Target], Tensor], /, name: str, higher_is_better: bool | None = None)
Bases: MetricCollection[Output, Target]
Subclass for a single metric.
- fun
the callable that computes the metric value.
- Type:
collections.abc.Callable[[drytorch.lib.objectives.Output, drytorch.lib.objectives.Target], torch.Tensor]
Initialize.
- class MetricCollection(**named_fn: Callable[[Output, Target], Tensor])
Bases: AverageObjective[Output, Target]
A collection of multiple metrics.
- named_fn
dictionary of named functions to calculate.
- Type:
dict[str, collections.abc.Callable[[drytorch.lib.objectives.Output, drytorch.lib.objectives.Target], torch.Tensor]]
Initialize.
- Parameters:
**named_fn (dict[str, Callable[[Output, Target], Tensor]]) – dictionary of named functions to calculate.
- calculate(outputs: Output, targets: Target) -> dict[str, Tensor]
Calculates the values for all metrics in the collection.
- __or__(other: MetricCollection[Output, Target]) -> MetricCollection[Output, Target]
Constructor using existing MetricCollection objects as templates.
The new instance does not aggregate the states of the original ones. To combine the states, call the merge_state method separately.
- Parameters:
other (MetricCollection[Output, Target]) – another MetricCollection object to combine with.
- Returns:
A new instance containing metrics from both instances.
- Return type:
MetricCollection[Output, Target]
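The distinction between combining functions and combining states can be illustrated with a small stand-in. `TinyCollection` below is a hypothetical class, not `MetricCollection`; its "state" is simply a list of per-batch results, and floats replace tensors. The `|` operator copies only the named functions into a fresh instance.

```python
from __future__ import annotations

from collections.abc import Callable

Fn = Callable[[list[float], list[float]], float]


class TinyCollection:
    """Hypothetical stand-in for MetricCollection: named functions plus accumulated state."""

    def __init__(self, **named_fn: Fn) -> None:
        self.named_fn: dict[str, Fn] = dict(named_fn)
        self.history: list[dict[str, float]] = []  # state filled by update()

    def calculate(self, outputs: list[float], targets: list[float]) -> dict[str, float]:
        return {name: fn(outputs, targets) for name, fn in self.named_fn.items()}

    def update(self, outputs: list[float], targets: list[float]) -> None:
        self.history.append(self.calculate(outputs, targets))

    def __or__(self, other: TinyCollection) -> TinyCollection:
        # Only the functions are combined: the new instance starts with an empty
        # history and neither operand is modified. Duplicate names raise TypeError.
        return TinyCollection(**self.named_fn, **other.named_fn)


def mae(outputs: list[float], targets: list[float]) -> float:
    return sum(abs(o - t) for o, t in zip(outputs, targets)) / len(outputs)


def max_error(outputs: list[float], targets: list[float]) -> float:
    return max(abs(o - t) for o, t in zip(outputs, targets))


first = TinyCollection(MAE=mae)
first.update([1.0, 2.0], [0.0, 4.0])
joined = first | TinyCollection(MaxError=max_error)
```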
- class MetricTracker(metric_name: str | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1))
Bases: Generic[Output, Target]
Handle metric value tracking and improvement detection.
This class is responsible for storing metric history, determining improvements, and managing patience countdown.
Note: this class can be used to automatically modify the training strategy. Therefore, it does not follow the library conventions for a tracker.
- best_is
whether higher or lower values are better.
- Type:
Literal[‘auto’, ‘higher’, ‘lower’]
- filter_fn
function to aggregate recent metric values.
Initialize.
- Parameters:
metric_name (str | None) – name of the metric to track.
min_delta (float) – minimum change required to qualify as an improvement.
patience (int) – number of checks to wait before triggering the callback.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values.
- property best_value: float
Get the best result observed so far.
- Returns:
the best filtered value (according to the ‘best_is’ criterion).
- Raises:
ResultNotAvailableError – if no results have been logged yet.
- property filtered_value: float
Get the current value.
- Returns:
the current value aggregated from recent ones.
- Raises:
ResultNotAvailableError – if no results have been logged yet.
- add_value(value: float) -> None
Add a new metric value to the history.
- Parameters:
value (float) – the metric value to add.
- Return type:
None
- is_better(value: float, reference: float) -> bool
Determine if the value is better than a reference value.
When best_is is in ‘auto’ mode, it is assumed that the given value is better than the first recorded one.
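The interplay of `min_delta`, `patience`, and `filter_fn` can be sketched as follows. `TinyTracker` is a hypothetical class, not `MetricTracker`: it omits the 'auto' mode and the metric name, and it assumes that patience counts consecutive non-improving checks, so that with `patience=0` the first check without improvement already exhausts it.

```python
from __future__ import annotations

import operator
from collections.abc import Callable, Sequence
from typing import Literal


class TinyTracker:
    """Hypothetical sketch of patience-based improvement detection (not MetricTracker)."""

    def __init__(
        self,
        min_delta: float = 1e-8,
        patience: int = 0,
        best_is: Literal["higher", "lower"] = "lower",
        filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1),
    ) -> None:
        self.min_delta = min_delta
        self.patience = patience
        self.best_is = best_is
        self.filter_fn = filter_fn
        self.history: list[float] = []
        self.best: float | None = None
        self.stale_checks = 0  # consecutive checks without improvement

    def is_better(self, value: float, reference: float) -> bool:
        # An improvement must beat the reference by more than min_delta.
        if self.best_is == "higher":
            return value > reference + self.min_delta
        return value < reference - self.min_delta

    def add_value(self, value: float) -> None:
        self.history.append(value)
        filtered = self.filter_fn(self.history)  # default: the latest value
        if self.best is None or self.is_better(filtered, self.best):
            self.best = filtered
            self.stale_checks = 0
        else:
            self.stale_checks += 1

    @property
    def patience_exhausted(self) -> bool:
        return self.stale_checks > self.patience


tracker = TinyTracker(patience=1)
for loss_value in [1.0, 0.5, 0.6, 0.55]:
    tracker.add_value(loss_value)
```

After the loop the best value is 0.5 and two checks in a row have failed to improve on it, which exceeds a patience of 1.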
- class Objective
Bases: ObjectiveProtocol[Output, Target]
Abstract base class for metrics or losses.
Initialize.
- compute() -> dict[str, Tensor]
Return the aggregated objective value(s).
Despite the name, which follows common practice, this method caches previously computed values and returns them when available.
- update(outputs: Output, targets: Target) -> dict[str, Tensor]
Updates the objective’s internal state with new outputs and targets.
- reset() -> None
Resets the internal state of the instance.
- Parameters:
self (Self)
- Return type:
None
- abstractmethod calculate(outputs: Output, targets: Target) -> dict[str, Tensor]
Method responsible for the calculations.
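The update / compute / reset cycle, with `calculate` left abstract and `compute` returning a cached result, can be sketched as below. `TinyObjective` and `MAE` are hypothetical classes, not the library's; floats replace tensors, and the aggregation is assumed to be an unweighted mean of per-batch values, which may differ from the default aggregation in `AverageObjective`.

```python
from __future__ import annotations

import abc


class TinyObjective(abc.ABC):
    """Hypothetical sketch of the update / compute / reset cycle with a cached result."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self._sums: dict[str, float] = {}
        self._count = 0
        self._cache: dict[str, float] | None = None

    @abc.abstractmethod
    def calculate(self, outputs: list[float], targets: list[float]) -> dict[str, float]:
        """Compute the values for one batch."""

    def update(self, outputs: list[float], targets: list[float]) -> dict[str, float]:
        values = self.calculate(outputs, targets)
        for name, value in values.items():
            self._sums[name] = self._sums.get(name, 0.0) + value
        self._count += 1
        self._cache = None  # new data invalidates the cached result
        return values

    def compute(self) -> dict[str, float]:
        # Reuse the previous result unless update() has been called since.
        if self._cache is None:
            self._cache = {name: total / self._count for name, total in self._sums.items()}
        return self._cache


class MAE(TinyObjective):
    def calculate(self, outputs: list[float], targets: list[float]) -> dict[str, float]:
        return {"MAE": sum(abs(o - t) for o, t in zip(outputs, targets)) / len(outputs)}


metric = MAE()
metric.update([1.0, 2.0], [0.0, 4.0])  # batch MAE: 1.5
metric.update([3.0], [3.5])  # batch MAE: 0.5
first = metric.compute()  # {'MAE': 1.0}
```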