"""Module containing classes for the evaluation of a model."""
from typing import Any, TypeVar
import torch
from typing_extensions import override
from drytorch.core import log_events
from drytorch.core import protocols as p
from drytorch.lib import runners
# Public names of this module, i.e. what ``from <module> import *`` exposes.
__all__ = [
    'Diagnostic',
    'EvaluationMixin',
    'Test',
    'Validation',
]

# Type variables parametrizing the generic evaluation classes below; each is
# bounded by the corresponding type declared in ``drytorch.core.protocols``.
Input = TypeVar('Input', bound=p.InputType)
Target = TypeVar('Target', bound=p.TargetType)
Output = TypeVar('Output', bound=p.OutputType)
[docs]
class EvaluationMixin(p.MonitorProtocol):
    """Mixin that evaluates in eval mode with PyTorch inference mode enabled.

    In this module it is listed before a runner class in the bases, so the
    cooperative ``super().__call__`` below reaches the runner implementation.
    """

    @torch.inference_mode()
    def __call__(self, store_outputs: bool = False) -> None:
        """Put the wrapped module in eval mode, then delegate the call.

        The whole call executes under ``torch.inference_mode``.

        Args:
            store_outputs: forwarded unchanged to the next ``__call__`` in
                the method resolution order.
        """
        wrapped_module = self.model.module
        wrapped_module.eval()
        # Cooperative call: the concrete implementation comes from the class
        # this mixin is combined with (presumably why typing is silenced).
        super().__call__(store_outputs)  # type: ignore
[docs]
class Diagnostic(
    EvaluationMixin,
    runners.ModelRunnerWithLogs[Input, Target, Output, Any],
):
    """Evaluate the model in inference mode without logging the metrics.

    Attributes:
        model: the model containing the weights to evaluate.
        loader: provides inputs and targets in batches.
        objective: processes the model outputs and targets.
        outputs_list: list of optionally stored outputs.
    """

    # NOTE(review): the summary says metrics are not logged, yet the base
    # class is ``ModelRunnerWithLogs``, the same one ``Validation`` uses.
    # Confirm whether a non-logging runner base was intended here.
[docs]
class Validation(
    EvaluationMixin,
    runners.ModelRunnerWithLogs[Input, Target, Output, Any],
):
    """Run a model over a dataset in inference mode.

    Usable directly for validating a model; the ``Test`` subclass covers
    testing.

    Attributes:
        model: the model containing the weights to evaluate.
        loader: provides inputs and targets in batches.
        objective: processes the model outputs and targets.
        outputs_list: list of optionally stored outputs.
    """

    def __init__(
        self,
        model: p.ModelProtocol[Input, Output],
        name: str = '',
        *,
        loader: p.LoaderProtocol[tuple[Input, Target]],
        metric: p.ObjectiveProtocol[Output, Target],
    ) -> None:
        """Set up the evaluation.

        Args:
            model: the model containing the weights to evaluate.
            name: the name of the object for logging purposes.
                Defaults to class name plus eventual counter.
            loader: provides inputs and targets in batches.
            metric: metric used to evaluate the model; it is handed to the
                parent initializer as its ``objective`` argument.
        """
        super().__init__(
            model,
            name=name,
            loader=loader,
            objective=metric,
        )
[docs]
class Test(Validation[Input, Target, Output]):
    """Measure model performance on a test dataset.

    Behaves like ``Validation`` and additionally emits test start and end
    log events around each call.

    Attributes:
        model: the model containing the weights to evaluate.
        loader: provides inputs and targets in batches.
        objective: processes the model outputs and targets.
        outputs_list: list of optionally stored outputs.
    """

    @override
    def __call__(self, store_outputs: bool = False) -> None:
        """Evaluate the model on the dataset, bracketed by test events.

        Args:
            store_outputs: whether to store model outputs. Defaults to False.
        """
        log_events.StartTestEvent(self.name, self.model.name)
        super().__call__(store_outputs)
        # Reached only when the evaluation above completes without raising.
        log_events.EndTestEvent(self.name, self.model.name)