"""Module containing utilities to extract readable representations.
Attributes:
MAX_DEPTH: Max number of recursions when representing an object.
MAX_REPR_SIZE: Max representation size for iterators and arrays.
INCLUDE_PROPERTIES: Whether to evaluate properties and represent them.
"""
from __future__ import annotations
import dataclasses
import datetime
import functools
import itertools
import math
import numbers
import types
from collections.abc import Hashable, Iterable
from itertools import count
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias
import numpy as np
import torch
from typing_extensions import override
__all__ = [
'INCLUDE_PROPERTIES',
'MAX_DEPTH',
'MAX_REPR_SIZE',
'recursive_repr',
]
if TYPE_CHECKING:
import numpy.typing as npt
from pandas.core.generic import NDFrame
GenericDict: TypeAlias = dict[Hashable, Any]
GenericList: TypeAlias = list[Any]
GenericSet: TypeAlias = set[Any]
GenericTuple: TypeAlias = tuple[Any, ...]
ndarray: TypeAlias = npt.NDArray[Any]
else:
from numpy import ndarray
GenericList = list
GenericDict = dict
GenericSet = set
GenericTuple = tuple
MAX_DEPTH: int = 10
MAX_REPR_SIZE: int = 10
INCLUDE_PROPERTIES: bool = False
class CreatedAtMixin:
"""Mixin saving instantiation timestamp.
Attributes:
ts_fmt: timestamp format string.
"""
ts_fmt: ClassVar[str] = '%Y-%m-%d@%Hh%Mm%Ss'
_created_at: datetime.datetime
def __init__(self, *args, **kwargs) -> None:
"""Initialize."""
self._created_at = datetime.datetime.now()
super().__init__(*args, **kwargs)
@property
def created_at(self) -> datetime.datetime:
"""Read-only timestamp."""
return self._created_at
@property
def created_at_str(self) -> str:
"""Read-only timestamp."""
return self._created_at.strftime(self.ts_fmt)
class DefaultName:
"""Add a counter to a prefix.
Attributes:
_prefixes: dictionary mapping prefixes to their counters.
"""
_prefixes: dict[str, count[int]]
def __init__(self) -> None:
"""Initialize."""
self._prefixes = {}
def __get__(self, instance: Any, objtype: type | None = None) -> str:
"""Return the default name for the instance or class."""
return instance.__name
def __set__(self, instance: Any, value: str) -> None:
"""Return the default name for the instance or class."""
value = value if value else instance.__class__.__name__
count_iter = self._prefixes.setdefault(value, itertools.count())
if count_value := next(count_iter):
value = f'{value}_{count_value}'
instance.__name = value
return
class LiteralStr(str):
"""YAML will attempt to use the pipe style for this class."""
@override
def __add__(self, other: str | LiteralStr) -> LiteralStr:
out = super().__add__(other)
return LiteralStr(out)
@dataclasses.dataclass(frozen=True)
class Omitted:
"""Represent omitted values in an iterable object.
Attributes:
count: how many elements have been omitted. Defaults to NAN (unknown).
"""
count: float = math.nan
[docs]
@functools.singledispatch
def recursive_repr(obj: object, *, depth: int | None = None) -> Any:
"""Create a hierarchical representation of an object.
It recursively represents each attribute of the object or the contained
items in tuples, lists, sets, and dictionaries. The latter structures are
limited in size by limiting the number of elements and replacing the others
with an Omitted instance. Arrays are represented using native representation
Numbers are returned as they are or converted to built-in types.
Args:
obj: The object to represent
depth: Maximum recursion depth allowed
Returns:
A readable representation of the object
"""
if depth is None:
depth = MAX_DEPTH
class_name = obj.__class__.__name__
if depth > 0:
attributes = _get_object_attributes(obj)
result_attrs = {}
for key, value in attributes.items():
if _should_skip_attribute(key, value, obj):
continue
result_attrs[key] = recursive_repr(value, depth=depth - 1)
if result_attrs:
return {'class': class_name, **dict(sorted(result_attrs.items()))}
return repr(obj) if _has_own_repr(obj) else class_name
@recursive_repr.register
def _(obj: Omitted, *, depth: int = 10) -> Omitted:
_not_used = depth
return obj
@recursive_repr.register
def _(obj: str, *, depth: int = 10) -> str:
_not_used = depth
return obj
@recursive_repr.register
def _(obj: None, *, depth: int = 10) -> None:
_not_used = depth
return obj
@recursive_repr.register
def _(obj: numbers.Number, *, depth: int = 10) -> numbers.Number:
if item_method := getattr(obj, 'item', None):
try:
obj = item_method()
except (TypeError, NotImplementedError):
pass
_not_used = depth
return obj
@recursive_repr.register
def _(obj: GenericTuple, *, depth: int = 1) -> tuple[Any, ...]:
return tuple(
recursive_repr(item, depth=depth - 1) for item in _limit_size(obj)
)
@recursive_repr.register
def _(obj: GenericList, *, depth: int = 10) -> list[Any]:
return [recursive_repr(item, depth=depth - 1) for item in _limit_size(obj)]
@recursive_repr.register
def _(obj: GenericSet, *, depth: int = 10) -> set[Hashable]:
return {recursive_repr(item, depth=depth - 1) for item in _limit_size(obj)}
@recursive_repr.register
def _(obj: GenericDict, *, depth: int = 10) -> dict[str, Any]:
out_dict: dict[str, Any] = {
str(key): recursive_repr(value, depth=depth - 1)
for key, value in list(obj.items())[:MAX_REPR_SIZE]
}
if len(obj) > MAX_REPR_SIZE:
out_dict['...'] = Omitted(len(obj) - MAX_REPR_SIZE)
return out_dict
@recursive_repr.register
def _(obj: torch.Tensor, *, depth: int = 10) -> LiteralStr:
_not_used = depth
return recursive_repr(obj.detach().cpu().numpy())
@recursive_repr.register
def _(obj: ndarray, *, depth: int = 10) -> LiteralStr:
size_factor = 2 ** (+obj.ndim - 1)
size_str = f'Array of size {obj.shape}\n'
with np.printoptions(
precision=3,
suppress=True,
threshold=MAX_REPR_SIZE // size_factor,
edgeitems=MAX_REPR_SIZE // (size_factor * 2),
):
_not_used = depth
return LiteralStr(size_str) + LiteralStr(obj)
@recursive_repr.register(type)
def _(obj, *, depth: int = 10) -> str:
_not_used = depth
return obj.__name__
@recursive_repr.register(types.FunctionType)
def _(obj, *, depth: int = 10) -> str:
_not_used: int = depth
return obj.__name__
def _get_object_attributes(obj: object) -> dict[str, Any]:
"""Extract all relevant attributes from an object."""
# Get instance attributes
attributes = getattr(obj, '__dict__', {}).copy()
# Add slot attributes if no __dict__
if not attributes:
slot_names = getattr(obj, '__slots__', [])
attributes = {name: getattr(obj, name, None) for name in slot_names}
# Add properties if enabled
if INCLUDE_PROPERTIES:
attributes.update(_get_object_properties(obj))
return attributes
def _get_object_properties(obj: object) -> dict[str, Any]:
"""Extract property values from an object."""
properties = {}
for cls in reversed(obj.__class__.__mro__):
for name, attr in cls.__dict__.items():
if isinstance(attr, property):
try:
properties[name] = getattr(obj, name)
except Exception as e:
properties[name] = str(e)
return properties
def _has_own_repr(obj: Any) -> bool:
"""Indicate whether __repr__ has been overridden."""
repr_str = repr(obj).lower() # Windows use the upper case
hex_id = hex(id(obj))[2:] # remove '0x' prefix
return not repr_str.endswith(hex_id + '>')
def _limit_size(container: Iterable[Any]) -> list[Any]:
"""Limit the size of iterables and adds an Omitted object."""
# prevents infinite iterators
if hasattr(container, '__len__'):
listed = list(container)
if len(listed) > MAX_REPR_SIZE:
omitted = [Omitted(len(listed) - MAX_REPR_SIZE)]
listed = (
listed[: MAX_REPR_SIZE // 2]
+ omitted
+ listed[-MAX_REPR_SIZE // 2 :]
)
else:
listed = []
iter_container = iter(container)
for _ in range(MAX_REPR_SIZE):
try:
value = next(iter_container)
listed.append(value)
except StopIteration:
break
else:
listed.append([Omitted()])
return listed
def _should_skip_attribute(key: str, value: Any, parent_obj: object) -> bool:
"""Determine if an attribute should be skipped during representation."""
if key.startswith('_'):
return True
if value is parent_obj or value is None:
return True
if hasattr(value, '__len__'):
try:
len_value = len(value)
except (TypeError, NotImplementedError):
...
else:
return len_value == 0
return False
try:
import pandas as pd
except (ImportError, ModuleNotFoundError):
pass
else:
from pandas.core.generic import NDFrame
class PandasPrintOptions:
"""Context manager to temporarily set Pandas display options.
Args:
precision: number of digits of precision for floating point output.
max_rows: maximum number of rows to display.
max_columns: maximum number of columns to display.
"""
_options: dict[str, int]
_original_options: dict[str, Any]
def __init__(
self, precision: int = 3, max_rows: int = 10, max_columns: int = 10
) -> None:
"""Initialize.
Args:
precision: see Pandas docs.
max_rows: see Pandas docs.
max_columns: see Pandas docs.
"""
self._options = {
'display.precision': precision,
'display.max_rows': max_rows,
'display.max_columns': max_columns,
}
self._original_options = {}
return
def __enter__(self) -> None:
"""Temporarily modify settings."""
self._original_options.update(
{key: pd.get_option(key) for key in self._options}
)
for key, value in self._options.items():
pd.set_option(key, value)
return
def __exit__(
self,
exc_type: None = None,
exc_val: None = None,
exc_tb: None = None,
) -> None:
"""Restore original settings."""
for key, value in self._original_options.items():
pd.set_option(key, value)
return
@recursive_repr.register
def _(obj: NDFrame, *, depth: int = 10) -> LiteralStr:
# only called when Pandas is imported
_not_used = depth
with PandasPrintOptions(
max_rows=MAX_REPR_SIZE, max_columns=MAX_REPR_SIZE
):
return LiteralStr(obj)