"""Support for torcheval syncing recommended for distributed training."""
import torch
from torcheval import metrics
from torcheval.metrics import toolkit
from drytorch.core import protocols as p
_Tensor = torch.Tensor
[docs]
def from_torcheval(
torch_eval: metrics.Metric[_Tensor | dict[str, _Tensor]],
) -> p.ObjectiveProtocol[_Tensor, _Tensor]:
"""Returns a wrapper of a Metric from torcheval with a sync method."""
class _TorchEvalWithSync(p.ObjectiveProtocol[_Tensor, _Tensor]):
name = 'Loss'
def __init__(
self, _metric: metrics.Metric[_Tensor | dict[str, _Tensor]]
) -> None:
self.metric = _metric
return
def compute(self) -> _Tensor | dict[str, _Tensor]:
return self.metric.compute()
def reset(self) -> None:
self.metric.reset()
return
def sync(self) -> None:
"""Use torcheval toolkit to synchronize and compute metrics."""
toolkit.sync_and_compute(self.metric)
return
def update(self, outputs: _Tensor, targets: _Tensor) -> None:
self.metric.update(outputs, targets)
return
return _TorchEvalWithSync(torch_eval)