"""Module containing internal exceptions for the drytorch package."""
import pathlib
import traceback
from typing import Any, ClassVar, Final
import torch
# Public API of this module: every exception and warning class defined below,
# listed in alphabetical order.
__all__ = [
'AccessOutsideScopeError',
'CannotStoreOutputWarning',
'CheckpointNotInitializedError',
'ComputedBeforeUpdatedWarning',
'ComputedMetricsTypeError',
'ConvergenceError',
'DatasetHasNoLengthError',
'DeviceMismatchError',
'DistributedDatasetNotDivisibleWarning',
'DistributedStorageWarning',
'DryTorchError',
'DryTorchWarning',
'EpochNotFoundError',
'ExperimentalFeatureWarning',
'FailedOptionalImportWarning',
'FuncNotApplicableError',
'LossNotScalarError',
'MetricNotFoundError',
'MissingParamError',
'ModelDeviceMismatchError',
'ModelNotFoundError',
'ModuleAlreadyRegisteredError',
'ModuleNotDistributedWarning',
'ModuleNotRegisteredError',
'NameAlreadyRegisteredError',
'NamedTupleOnlyError',
'NestedScopeError',
'NoActiveExperimentError',
'NoPreviousRunsWarning',
'NotExistingRunWarning',
'ObjectiveSyncWarning',
'OptimizerNotLoadedWarning',
'PastEpochWarning',
'RecursionWarning',
'ResultNotAvailableError',
'RunAlreadyCompletedWarning',
'RunAlreadyRecordedError',
'RunAlreadyRunningWarning',
'RunNotRecordedError',
'RunNotStartedWarning',
'TerminatedTrainingWarning',
'TrackerAlreadyRegisteredError',
'TrackerError',
'TrackerExceptionWarning',
'TrackerNotUsedError',
]
[docs]
class DryTorchError(Exception):
    """Root of the drytorch exception hierarchy.

    Subclasses override ``_template``; the positional arguments handed to the
    constructor are substituted into it with ``str.format`` to build the
    exception message.
    """

    _template: ClassVar[str] = ''

    def __init__(self, *args: Any) -> None:
        """Build the exception message from the class template.

        Args:
            *args: values substituted, in order, into ``_template``.
        """
        message = self._template.format(*args)
        super().__init__(message)
[docs]
class DryTorchWarning(UserWarning):
    """Root of the drytorch warning hierarchy.

    Subclasses override ``_template``; the positional arguments handed to the
    constructor are substituted into it with ``str.format`` to build the
    warning message.
    """

    _template: ClassVar[str] = ''

    def __init__(self, *args: Any) -> None:
        """Build the warning message from the class template.

        Args:
            *args: values substituted, in order, into ``_template``.
        """
        message = self._template.format(*args)
        super().__init__(message)
[docs]
class TrackerError(DryTorchError):
    """Error reported by a tracker object during experiment tracking.

    The message is prefixed with the tracker's class name in brackets.
    """

    _template = '[{}] {}'

    def __init__(self, tracker: Any, tracker_msg: str) -> None:
        """Initialize.

        Args:
            tracker: the tracker object that ran into the problem.
            tracker_msg: the description of the problem given by the tracker.
        """
        tracker_class_name = tracker.__class__.__name__
        self.tracker = tracker
        super().__init__(tracker_class_name, tracker_msg)
[docs]
class AccessOutsideScopeError(DryTorchError):
    """Signals an operation attempted while not inside an experiment scope."""

    _template = 'Operation only allowed within an experiment scope.'
[docs]
class CheckpointNotInitializedError(DryTorchError):
    """Signals use of a checkpoint that has no model registered to it."""

    _template = 'The checkpoint did not register any model.'
[docs]
class ComputedMetricsTypeError(DryTorchError):
    """Signals that computed metrics are neither a mapping nor a tensor.

    Attributes:
        computed_metrics_type: the offending type.
    """

    _template = (
        'Expected computed metrics as a Mapping[str, Tensor] or Tensor. Got {}.'
    )

    def __init__(self, computed_metrics_type: type) -> None:
        """Initialize.

        Args:
            computed_metrics_type: the type the computed metrics actually had.
        """
        type_name = computed_metrics_type.__name__
        self.computed_metrics_type: Final = computed_metrics_type
        super().__init__(type_name)
[docs]
class ConvergenceError(DryTorchError):
    """Signals that a module did not converge during training.

    Attributes:
        criterion: the value of the convergence criterion.
    """

    _template = 'The module did not converge (criterion is {}).'

    def __init__(self, criterion: float) -> None:
        """Initialize.

        Args:
            criterion: the value of the criterion that was not satisfied.
        """
        super().__init__(criterion)
        self.criterion: Final = criterion
[docs]
class DatasetHasNoLengthError(DryTorchError):
    """Signals a dataset that lacks a ``__len__`` implementation."""

    _template = 'Dataset does not implement __len__ method.'
[docs]
class DeviceMismatchError(DryTorchError):
    """Signals a metric stored on a device other than the expected one.

    Attributes:
        metric_name: the name of the metric.
        metric_device: the device the metric is stored on.
        target_device: the device the metric was expected on.
    """

    _template = 'Metric {} is stored on {} but expected on {}.'

    def __init__(
        self,
        metric_name: str,
        metric_device: torch.device,
        target_device: torch.device,
    ) -> None:
        """Initialize.

        Args:
            metric_name: the name of the metric.
            metric_device: the device the metric is stored on.
            target_device: the device the metric was expected on.
        """
        super().__init__(metric_name, metric_device, target_device)
        self.metric_name: Final = metric_name
        self.metric_device: Final = metric_device
        self.target_device: Final = target_device
[docs]
class EpochNotFoundError(DryTorchError):
    """Raised when no checkpoint for the requested epoch is found.

    Attributes:
        epoch: the epoch that was requested.
        checkpoint_directory: the directory that was searched.
        model_directory: alias of ``checkpoint_directory``, kept for backward
            compatibility.
    """

    _template = 'No checkpoints for epoch {} found in {}.'

    def __init__(self, epoch: int, checkpoint_directory: pathlib.Path) -> None:
        """Initialize.

        Args:
            epoch: the epoch that was not found.
            checkpoint_directory: the directory where no checkpoint for the
                epoch was found.
        """
        self.epoch: Final = epoch
        self.checkpoint_directory: Final = checkpoint_directory
        # Original attribute name; retained so existing handlers keep working.
        self.model_directory: Final = checkpoint_directory
        super().__init__(epoch, checkpoint_directory)
[docs]
class FuncNotApplicableError(DryTorchError):
    """Signals that a function cannot be applied to objects of a given type.

    Attributes:
        func_name: the name of the function.
        type_name: the name of the unsupported type.
    """

    _template = 'Cannot apply function {} on type {}.'

    def __init__(self, func_name: str, type_name: str) -> None:
        """Initialize.

        Args:
            func_name: the name of the function that could not be applied.
            type_name: the name of the type the function does not support.
        """
        super().__init__(func_name, type_name)
        self.func_name: Final = func_name
        self.type_name: Final = type_name
[docs]
class LossNotScalarError(DryTorchError):
    """Signals a loss tensor that is not a scalar.

    Attributes:
        size: the size of the offending loss tensor.
    """

    _template = 'Loss must be a scalar but got Tensor of shape {}.'

    def __init__(self, size: torch.Size) -> None:
        """Initialize.

        Args:
            size: the size of the loss tensor that is not a scalar.
        """
        super().__init__(size)
        self.size: Final = size
[docs]
class MetricNotFoundError(DryTorchError):
    """Raised when a requested metric is not found in the specified source.

    Attributes:
        source_name: the name of the source that was searched.
        metric_name: the name of the metric, exactly as given by the caller.
    """

    _template = 'No metric {}found in {}.'

    def __init__(self, source_name: str, metric_name: str) -> None:
        """Initialize.

        Args:
            source_name: the name of the source where the metric was not found.
            metric_name: the name of the metric that was not found. An empty
                string yields the message 'No metric found in ...'.
        """
        self.source_name: Final = source_name
        self.metric_name: Final = metric_name
        # The trailing space is a formatting detail of the message only; it
        # must not leak into the public ``metric_name`` attribute.
        padded_name = metric_name + ' ' if metric_name else ''
        super().__init__(padded_name, source_name)
[docs]
class MissingParamError(DryTorchError):
    """Signals learning-rate parameter groups that leave out some modules.

    Attributes:
        module_names: the module names given to the constructor.
        lr_param_groups: the group names given to the constructor.
        missing: the module names that are not among the group names.
    """

    _template = 'Parameter groups in input learning rate miss parameters {}.'

    def __init__(
        self, module_names: list[str], lr_param_groups: list[str]
    ) -> None:
        """Initialize.

        Args:
            module_names: names of the modules expected to have parameters.
            lr_param_groups: group names in the learning rate configuration.
        """
        uncovered = set(module_names).difference(lr_param_groups)
        self.module_names: Final = module_names
        self.lr_param_groups: Final = lr_param_groups
        self.missing: Final = uncovered
        super().__init__(uncovered)
[docs]
class ModuleAlreadyRegisteredError(DryTorchError):
    """Signals that a model's module is already registered in the run.

    Attributes:
        model_name: the name of the model.
        exp_name: the name of the experiment.
        run_id: the id of the run.
    """

    _template = (
        'Module from model {} is already registered in experiment {} run {}.'
    )

    def __init__(self, model_name: str, exp_name: str, run_id: str) -> None:
        """Initialize.

        Args:
            model_name: the name of the model that is already registered.
            exp_name: the name of the current experiment.
            run_id: the id of the current run.
        """
        super().__init__(model_name, exp_name, run_id)
        self.model_name: Final = model_name
        self.exp_name: Final = exp_name
        self.run_id: Final = run_id
[docs]
class ModuleNotRegisteredError(DryTorchError):
    """Signals access to a module that has not been registered in the run.

    Attributes:
        model_name: the name of the model.
        exp_name: the name of the experiment.
        run_id: the id of the run.
    """

    _template = (
        'Module from model {} is not registered in the current run {} - {}.'
    )

    def __init__(self, model_name: str, exp_name: str, run_id: str) -> None:
        """Initialize.

        Args:
            model_name: the name of the model that is not registered.
            exp_name: the name of the current experiment.
            run_id: the id of the current run.
        """
        super().__init__(model_name, exp_name, run_id)
        self.model_name: Final = model_name
        self.exp_name: Final = exp_name
        self.run_id: Final = run_id
[docs]
class ModelDeviceMismatchError(DryTorchError):
    """Signals differing device types for parameters and outputs."""

    _template = (
        "In multiprocessing, parameters' and outputs' device type must match."
    )
[docs]
class ModelNotFoundError(DryTorchError):
    """Signals that the checkpoint directory holds no saved module.

    Attributes:
        checkpoint_directory: the directory that was searched.
    """

    _template = 'No saved module found in {}.'

    def __init__(self, checkpoint_directory: pathlib.Path) -> None:
        """Initialize.

        Args:
            checkpoint_directory: the directory where nothing was found.
        """
        super().__init__(checkpoint_directory)
        self.checkpoint_directory: Final = checkpoint_directory
[docs]
class NameAlreadyRegisteredError(DryTorchError):
    """Raised when attempting to register a name already in use.

    Attributes:
        name: the name that is already registered.
    """

    _template = 'Name {} has already been registered in the current run.'

    def __init__(self, name: str) -> None:
        """Initialize.

        Args:
            name: the name that is already registered.
        """
        # Stored like the constructor arguments of the sibling exceptions, so
        # handlers can inspect it without parsing the message.
        self.name: Final = name
        super().__init__(name)
[docs]
class NamedTupleOnlyError(DryTorchError):
    """Signals a tuple subtype that is not a namedtuple class.

    Attributes:
        tuple_type: the type of the tuple that was provided.
    """

    _template = (
        'The only accepted subtypes of tuple are namedtuple classes. Got {}.'
    )

    def __init__(self, tuple_type: str) -> None:
        """Initialize.

        Args:
            tuple_type: the type of the tuple that was actually provided.
        """
        super().__init__(tuple_type)
        self.tuple_type: Final = tuple_type
[docs]
class NestedScopeError(DryTorchError):
    """Raised when attempting to nest an experiment scope within another one.

    Attributes:
        current_exp_name: the name of the currently active experiment.
        new_exp_name: the name of the experiment that cannot be started.
    """

    _template = 'Cannot start Experiment {} within Experiment {} scope.'

    def __init__(self, current_exp_name: str, new_exp_name: str) -> None:
        """Initialize.

        Args:
            current_exp_name: the name of the currently active experiment.
            new_exp_name: the name of the experiment that cannot be started.
        """
        self.current_exp_name: Final = current_exp_name
        self.new_exp_name: Final = new_exp_name
        # The template names the experiment being started first and the
        # enclosing (current) experiment second.
        super().__init__(new_exp_name, current_exp_name)
[docs]
class NoActiveExperimentError(DryTorchError):
    """Raised when no experiment is currently active.

    Attributes:
        experiment_name: the experiment name given to the constructor, if any.
        experiment_class: the experiment class given to the constructor, if
            any.
    """

    _template = 'No experiment {}has been started.'

    def __init__(
        self,
        experiment_name: str | None = None,
        experiment_class: type | None = None,
    ) -> None:
        """Initialize.

        Args:
            experiment_name: the name of the expected experiment. When given,
                it takes precedence over ``experiment_class`` in the message.
            experiment_class: the class of the expected experiment. Used in
                the message only when no name is given.
        """
        self.experiment_name: Final = experiment_name
        self.experiment_class: Final = experiment_class
        if experiment_name is not None:
            specify_string = f'named {experiment_name} '
        elif experiment_class is not None:
            # ``experiment_class`` is itself a class: report its own name
            # rather than the name of its metaclass.
            specify_string = f'of class {experiment_class.__name__} '
        else:
            specify_string = ''
        super().__init__(specify_string)
[docs]
class ResultNotAvailableError(DryTorchError):
    """Signals access to a result before the hook producing it was called."""

    _template = (
        'The result will be available only after the hook has been called.'
    )
[docs]
class TrackerAlreadyRegisteredError(DryTorchError):
    """Raised when attempting to register an already registered tracker.

    Attributes:
        tracker_name: the name of the tracker that is already registered.
        exp_name: the name of the experiment the tracker is registered in.
    """

    _template = 'Tracker {} already registered in experiment {}.'

    def __init__(self, tracker_name: str, exp_name: str) -> None:
        """Initialize.

        Args:
            tracker_name: the name of the tracker that is already registered.
            exp_name: the name of the experiment where to register the tracker.
        """
        self.tracker_name: Final = tracker_name
        # Stored alongside ``tracker_name`` so that both constructor arguments
        # are available to handlers, as in the sibling exceptions.
        self.exp_name: Final = exp_name
        super().__init__(tracker_name, exp_name)
[docs]
class TrackerNotUsedError(DryTorchError):
    """Signals access to a tracker not used in the active experiment.

    Attributes:
        tracker_name: the name of the tracker.
    """

    _template = 'Tracker {} has not been used in the active experiment'

    def __init__(self, tracker_name: str) -> None:
        """Initialize.

        Args:
            tracker_name: the name of the tracker that has not been used.
        """
        super().__init__(tracker_name)
        self.tracker_name: Final = tracker_name
[docs]
class CannotStoreOutputWarning(DryTorchWarning):
    """Warns that an error prevented the output from being stored.

    Attributes:
        error: the error that prevented the storage.
    """

    _template = 'Impossible to store output because the following error.\n{}'

    def __init__(self, error: BaseException) -> None:
        """Initialize.

        Args:
            error: the error that prevented the output from being stored.
        """
        error_text = str(error)
        self.error: Final = error
        super().__init__(error_text)
[docs]
class ComputedBeforeUpdatedWarning(DryTorchWarning):
    """Warns that ``compute`` was called on a calculator before any update.

    Attributes:
        calculator: the calculator object concerned.
    """

    _template = 'The ``compute`` method of {} was called before its updating.'

    def __init__(self, calculator: Any) -> None:
        """Initialize.

        Args:
            calculator: the calculator whose ``compute`` was called too early.
        """
        calculator_class_name = calculator.__class__.__name__
        self.calculator: Final = calculator
        super().__init__(calculator_class_name)
[docs]
class DistributedDatasetNotDivisibleWarning(DryTorchWarning):
    """Warns that the dataset size is not a multiple of the process count.

    Attributes:
        actor: the name of the actor experiencing the issue.
        dataset_size: the size of the dataset.
        num_processes: the number of processes.
    """

    _template = (
        '{} has encountered the following issue with distributed evaluation: \n'
        'The dataset size: {} is not divisible by the number of processes: {}. '
        'Some samples will be evaluated twice, and metrics may not be reliable.'
    )

    def __init__(self, name: str, len_dataset: int, n_processes: int) -> None:
        """Initialize.

        Args:
            name: the name of the actor experiencing the issue.
            len_dataset: the number of samples in the dataset.
            n_processes: the number of processes sharing the dataset.
        """
        super().__init__(name, len_dataset, n_processes)
        self.actor: Final = name
        self.dataset_size: Final = len_dataset
        self.num_processes: Final = n_processes
[docs]
class DistributedStorageWarning(DryTorchWarning):
    """Warns that the storage of a distributed model is not synchronized.

    Attributes:
        error: the error that occurred while synchronizing the storage.
    """

    _template = 'The storage of the distributed model is not synchronized:\n{}.'

    def __init__(self, error: BaseException) -> None:
        """Initialize.

        Args:
            error: the error raised during storage synchronization.
        """
        error_text = str(error)
        self.error: Final = error
        super().__init__(error_text)
[docs]
class ExperimentalFeatureWarning(DryTorchWarning):
    """Warns that a feature in use is experimental and may change.

    Attributes:
        feature: the experimental feature.
    """

    _template = '{} is an experimental feature and may change in the future.'

    def __init__(self, feature: str) -> None:
        """Initialize.

        Args:
            feature: the experimental feature being used.
        """
        super().__init__(feature)
        self.feature: Final = feature
[docs]
class FailedOptionalImportWarning(DryTorchWarning):
    """Warns that an optional dependency could not be imported.

    Attributes:
        package_name: the name of the package.
    """

    _template = (
        'Failed to import optional dependency {}. Install for better support.'
    )

    def __init__(self, package_name: str) -> None:
        """Initialize.

        Args:
            package_name: the name of the package whose import failed.
        """
        super().__init__(package_name)
        self.package_name: Final = package_name
[docs]
class ModuleNotDistributedWarning(DryTorchWarning):
    """Warns that no distributed wrapper was detected around the model."""

    _template = 'Distributed wrapper not detected: model weights may diverge.'
[docs]
class NoPreviousRunsWarning(DryTorchWarning):
    """Warns that resuming the last run was requested but no run exists."""

    _template = 'No previous runs found. Starting a new one.'
[docs]
class NotExistingRunWarning(DryTorchWarning):
    """Warns that the run requested for resuming does not exist.

    Attributes:
        run_id: the id of the run.
    """

    _template = 'Run with id {} not found. Starting a new one.'

    def __init__(self, run_id: str) -> None:
        """Initialize.

        Args:
            run_id: the id of the run that could not be found.
        """
        super().__init__(run_id)
        self.run_id: Final = run_id
[docs]
class ObjectiveSyncWarning(DryTorchWarning):
    """Warns about an issue met while synchronizing an objective.

    Attributes:
        issue: the issue that was encountered.
        recommend: the recommended action.
    """

    _template = (
        'Objective synchronization encountered issue: {}. Recommend to: {} .'
    )

    def __init__(self, issue: str, recommend: str) -> None:
        """Initialize.

        Args:
            issue: the issue encountered with the objective.
            recommend: the action recommended to fix the issue.
        """
        super().__init__(issue, recommend)
        self.issue: Final = issue
        self.recommend: Final = recommend
[docs]
class OptimizerNotLoadedWarning(DryTorchWarning):
    """Warns that the optimizer could not be loaded correctly.

    Attributes:
        error: the error that occurred while loading the optimizer.
    """

    _template = 'The optimizer has not been correctly loaded:\n{}'

    def __init__(self, error: BaseException) -> None:
        """Initialize.

        Args:
            error: the error raised while loading the optimizer.
        """
        super().__init__(error)
        self.error: Final = error
[docs]
class PastEpochWarning(DryTorchWarning):
    """Warns that training was requested until an epoch already reached.

    Attributes:
        selected_epoch: the epoch that training was requested until.
        current_epoch: the current epoch number.
    """

    _template = 'Training until epoch {} stopped: current epoch is already {}.'

    def __init__(self, selected_epoch: int, current_epoch: int) -> None:
        """Initialize.

        Args:
            selected_epoch: the epoch until which training was requested.
            current_epoch: the number of the current epoch.
        """
        super().__init__(selected_epoch, current_epoch)
        self.selected_epoch: Final = selected_epoch
        self.current_epoch: Final = current_epoch
[docs]
class RecursionWarning(DryTorchWarning):
    """Warns that recursive objects prevented metadata extraction."""

    _template = (
        'Impossible to extract metadata because there are recursive objects.'
    )
[docs]
class RunAlreadyRecordedError(DryTorchError):
    """Signals an attempt to record a run that is already recorded.

    Attributes:
        run_id: the id of the run.
        exp_name: the name of the experiment.
    """

    _template = (
        'Run {} already recorded in experiment {}. Use resume=True to resume.'
    )

    def __init__(self, run_id: str, exp_name: str) -> None:
        """Initialize.

        Args:
            run_id: the id of the run that is already recorded.
            exp_name: the name of the experiment in which to record the run.
        """
        super().__init__(run_id, exp_name)
        self.run_id: Final = run_id
        self.exp_name: Final = exp_name
[docs]
class RunAlreadyCompletedWarning(DryTorchWarning):
    """Warns that a stop was attempted on an already completed run."""

    _template = 'Attempted to stop a Run instance that is already completed.'
[docs]
class RunAlreadyRunningWarning(DryTorchWarning):
    """Warns that a start was attempted on a run that is already running."""

    _template = 'Attempted to start a Run instance that is already running.'
[docs]
class RunNotStartedWarning(DryTorchWarning):
    """Warns that a stop was attempted on a run that is not active."""

    _template = 'Attempted to stop a Run instance that is not active.'
[docs]
class RunNotRecordedError(DryTorchError):
    """Signals an attempt to update a run that is not recorded.

    Attributes:
        run_id: the id of the run.
    """

    _template = 'Run with id {} is not recorded.'

    def __init__(self, run_id: str) -> None:
        """Initialize.

        Args:
            run_id: the id of the run that is not recorded.
        """
        super().__init__(run_id)
        self.run_id: Final = run_id
[docs]
class TerminatedTrainingWarning(DryTorchWarning):
    """Warns that training was attempted on a module after termination."""

    _template = 'Attempted to train module after termination.'
[docs]
class TrackerExceptionWarning(DryTorchWarning):
    """Warning raised when a tracker encounters an error and is skipped.

    Attributes:
        subscriber_name: the name of the tracker that encountered the error.
        error: the error that occurred in the tracker.
    """

    _template = (
        'Tracker {} encountered the following error and was skipped:\n{}'
    )

    def __init__(self, subscriber_name: str, error: BaseException) -> None:
        """Initialize.

        Args:
            subscriber_name: the name of the tracker that encountered the error.
            error: the error that occurred in the tracker.
        """
        self.subscriber_name: Final = subscriber_name
        self.error: Final = error
        # Format the traceback of ``error`` itself rather than of whatever
        # exception happens to be handled at construction time, so the
        # message is also meaningful outside an ``except`` block.
        traceback_lines = traceback.format_exception(
            type(error), error, error.__traceback__
        )
        formatted_traceback: Final = ''.join(traceback_lines)
        super().__init__(subscriber_name, formatted_traceback)