"""Module containing utilies to ensure compatibility with torchmetrics."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import torch
from drytorch.core import protocols as p
__all__ = [
'from_torchmetrics',
]
if TYPE_CHECKING:
from torchmetrics import metric
_Tensor = torch.Tensor
[docs]
def from_torchmetrics(
torch_metric: metric.CompositionalMetric,
) -> p.LossProtocol[_Tensor, _Tensor]:
"""Returns a wrapper of a CompositionalMetric for integration."""
class _TorchMetricCompositionalMetric(p.LossProtocol[_Tensor, _Tensor]):
"""Wrapper of Compositional Metric reporting internal calculations.
Note that this CompositionalMetric can be used as loss.
"""
name = 'Loss'
def __init__(self, _metric: metric.CompositionalMetric) -> None:
self.metric = _metric
self.metric.sync_on_compute = False
self.metric.dist_sync_on_step = False
return
def compute(self) -> dict[str, _Tensor]:
"""Output a dictionary of metric values for each component."""
dict_output = dict[str, _Tensor](
{'Combined Loss': self.metric.compute()}
)
metric_list = list[type(self.metric.metric_b)]()
metric_list.append(self.metric)
while metric_list:
metric_ = metric_list.pop()
if isinstance(metric_, self.metric.__class__):
metric_list.extend([metric_.metric_b, metric_.metric_a])
elif (
isinstance(metric_, float | int | _Tensor)
or metric_ is None
):
continue
else:
if isinstance(value := metric_.compute(), _Tensor):
dict_output[metric_.__class__.__name__] = value
return dict_output
def forward(self, outputs: _Tensor, targets: _Tensor) -> _Tensor:
return self.metric(outputs, targets)
def reset(self) -> Any:
self.metric.reset()
def update(self, outputs: _Tensor, targets: _Tensor) -> Any:
self.metric.update(outputs, targets)
return _TorchMetricCompositionalMetric(torch_metric)