"""Module containing functions for nested containers."""
import copy
import dataclasses
from collections.abc import Callable, MutableMapping, MutableSequence
from typing import TYPE_CHECKING, Any, TypeVar, overload
if TYPE_CHECKING:
from _typeshed import DataclassInstance
import torch
from drytorch.core import exceptions
# Names exported by ``from <module> import *``.
__all__ = [
    'apply',
    'apply_cpu_detach',
    'apply_to',
]

# Generic type variables used in the signatures of this module.
_T = TypeVar('_T')
_C = TypeVar('_C')

# DataclassInstance is only imported under TYPE_CHECKING, so the bound can
# only be declared for static type checkers; at runtime _D is unbounded.
if TYPE_CHECKING:
    _D = TypeVar('_D', bound=DataclassInstance)
else:
    _D = TypeVar('_D')

# Sentinel used as a getattr default to detect attributes that are not set.
_MISSING = object()
@overload
def recursive_apply(
    obj: _C, expected_type: type[_T], func: Callable[[_T], _T]
) -> _C: ...


@overload
def recursive_apply(
    obj: _T, expected_type: type[_T], func: Callable[[_T], _T]
) -> _T: ...


def recursive_apply(
    obj: _C, expected_type: type[_T], func: Callable[[_T], _T]
) -> _C | _T:
    """Look for an expected type and apply a given function.

    The implementation is similar to default_convert in
    github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/collate.py.

    It makes a shallow copy of a MutableMapping or MutableSequence container
    and replaces each element of the copy with the result of the function
    (for elements of the expected type) or of a recursive call (for nested
    containers). The input container itself is left unmodified. Plain tuples
    are rebuilt from the converted items; for any other tuple subclass
    (expected to be a namedtuple) the class constructor is called with the
    converted items as positional arguments. Note that when applied after
    default_convert, the only objects of type tuple are namedtuple classes.

    Args:
        obj: a container containing the expected objects and other containers.
        expected_type: the type of the objects to modify.
        func: a function that modifies objects of the expected type.

    Returns:
        The modified object or a copy containing the modified objects.

    Raises:
        FuncNotApplicableError: if the object is of an unexpected type.
        NamedTupleOnlyError: if a tuple subclass cannot be rebuilt by passing
            the converted items to its constructor as positional arguments.
    """
    if isinstance(obj, expected_type):
        return func(obj)

    if isinstance(obj, MutableMapping):
        mapping = copy.copy(obj)
        mapping.update(
            {
                key: recursive_apply(item, expected_type, func)
                for key, item in obj.items()
            }
        )
        return mapping  # type: ignore

    if isinstance(obj, MutableSequence):
        sequence = copy.copy(obj)
        for i, value in enumerate(obj):
            sequence[i] = recursive_apply(value, expected_type, func)
        return sequence  # type: ignore

    if isinstance(obj, tuple):
        # Convert the items eagerly, before entering the try block: a lazy
        # generator would be consumed inside it, and a TypeError raised by
        # func (or by a nested call) would be misreported as a
        # NamedTupleOnlyError.
        new = [recursive_apply(item, expected_type, func) for item in obj]
        if obj.__class__ is tuple:
            return obj.__class__(new)  # type: ignore
        try:
            return obj.__class__(*new)  # type: ignore
        except TypeError as te:
            raise exceptions.NamedTupleOnlyError(obj.__class__.__name__) from te

    # functools.partial objects and callable instances have no __name__.
    func_name = getattr(func, '__name__', repr(func))
    raise exceptions.FuncNotApplicableError(func_name, obj.__class__.__name__)
def _dataclass_apply(
    obj: _D, expected_type: type[_T], func: Callable[[_T], _T]
) -> _D:
    """Build a copy of a dataclass instance with converted field values.

    Every field that is currently set on the instance and is accepted by the
    generated ``__init__`` is passed through recursive_apply; the results are
    handed to dataclasses.replace to create the new instance.

    Args:
        obj: the dataclass instance to convert.
        expected_type: the type of the objects to modify.
        func: a function that modifies objects of the expected type.

    Returns:
        A new instance created by dataclasses.replace.
    """
    replacements: dict[str, Any] = {}
    for field in dataclasses.fields(obj):
        current = getattr(obj, field.name, _MISSING)
        # Skip attributes that are not set and fields that cannot be passed
        # to the constructor.
        if current is _MISSING or not field.init:
            continue
        replacements[field.name] = recursive_apply(
            current, expected_type=expected_type, func=func
        )
    return dataclasses.replace(obj, **replacements)
def apply(obj: _C, expected_type: type[_T], func: Callable[[_T], _T]) -> _C:
    """Extend recursive_apply to dataclasses and objects with attributes.

    An object of the expected type is passed directly to the function.
    A dataclass instance is rebuilt from its converted fields. Any other
    object that stores attributes (in ``__dict__`` or ``__slots__``) is
    shallow-copied and every attribute of the copy is set to the value
    returned by recursive_apply. Everything else is delegated to
    recursive_apply.

    Args:
        obj: container or class containing other containers and tensors.
        expected_type: the type of the objects to modify.
        func: a function that modifies objects of the expected type.

    Returns:
        The container or class with the modified objects.
    """
    # An instance of the expected type is a leaf even when it carries
    # attributes of its own: apply func instead of traversing its attributes.
    if isinstance(obj, expected_type):
        return func(obj)  # type: ignore

    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        try:
            return _dataclass_apply(obj, expected_type, func)
        except (TypeError, ValueError, AttributeError):
            # Best effort: dataclasses.replace cannot rebuild every dataclass
            # (it raises ValueError, e.g., for an InitVar without default),
            # so fall back to the generic attribute handling below.
            pass

    dict_attr: dict[str, Any] = {}
    if hasattr(obj, '__dict__'):
        dict_attr.update(obj.__dict__)

    if slots := getattr(obj, '__slots__', None):
        if isinstance(slots, str):
            # A string __slots__ declares a single slot, not one per character.
            slots = (slots,)
        for key in slots:
            value = getattr(obj, key, _MISSING)
            if value is not _MISSING:  # slotted attributes may not be set
                dict_attr[key] = value

    if dict_attr:
        obj_copy = copy.copy(obj)
        for key, value in dict_attr.items():
            setattr(
                obj_copy,
                key,
                recursive_apply(value, expected_type=expected_type, func=func),
            )
        return obj_copy

    return recursive_apply(obj, expected_type=expected_type, func=func)
def apply_to(obj: _C, device: torch.device) -> _C:
    """Change the device of tensors inside a container.

    Args:
        obj: container or class containing other containers and tensors.
        device: the target device.

    Returns:
        the same container with the tensor on the target device.
    """
    # Compare the device type rather than the device object:
    # torch.device('cpu:0') != torch.device('cpu'), which would request a
    # non-blocking copy to the host. Such copies are not synchronized with
    # the caller (see the PyTorch notes on non_blocking), so only use them
    # when the target is not the CPU.
    non_blocking = torch.device(device).type != 'cpu'

    def _to_device(tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to(device, non_blocking=non_blocking)

    return apply(obj, expected_type=torch.Tensor, func=_to_device)
def apply_cpu_detach(obj: _C) -> _C:
    """Detach the tensors inside a container and move them to the cpu.

    Args:
        obj: container or class containing other containers and tensors.

    Returns:
        the result of apply, with every tensor detached and stored on cpu.
    """

    def _cpu_detach(tensor: torch.Tensor) -> torch.Tensor:
        # Detach from the autograd graph first, then transfer to the host.
        detached = tensor.detach()
        return detached.cpu()

    return apply(obj, func=_cpu_detach, expected_type=torch.Tensor)