drytorch.lib.hooks

Module containing registry, callbacks, and hooks for a Trainer.

Functions

call_every(interval[, start])

Create a transformer for periodic hook execution.

static_hook_class(cls)

Class decorator to wrap a callable class into a static hook type.

Classes

CallEvery(wrapped[, parameters, interval, start])

Metadata-aware wrapper for periodic execution.

ChangeSchedulerOnPlateauCallback([metric, ...])

Change the learning rate schedule when a metric has stopped improving.

EarlyStoppingCallback([metric, monitor, ...])

Implement early stopping logic for training models.

Hook(wrapped[, base_hook_name, parameters])

Wrapper for callable taking a Trainer as input.

HookRegistry()

A registry for managing and executing hooks.

MetricExtractor([metric, monitor])

Handle extraction of metrics from trainer/validation protocols.

MetricMonitor([metric, monitor, min_delta, ...])

Handle metric monitoring and alerts when performance stops improving.

PruneCallback(thresholds[, metric, monitor, ...])

Implement pruning logic for training models.

ReduceLROnPlateau([metric, monitor, ...])

Reduce the learning rate when a metric has stopped improving.

RestartScheduleOnPlateau([metric, monitor, ...])

Restart the learning rate schedule when a metric has stopped improving.

StaticHook(wrapped[, base_hook_name, parameters])

Ignore arguments and execute a wrapped function.

TrainerHook()

Compose control flow on side effects while keeping track of parameters.

class CallEvery(wrapped: Callable[[TrainerProtocol[Input, Target, Output]], None], parameters: dict[str, Any] | None = None, interval: int = 1, start: int = 0)

Bases: Hook[Input, Target, Output]

Metadata-aware wrapper for periodic execution.

parameters

metadata associated with the hook.

Type:

dict[str, Any]

Initialize.

Parameters:
  • wrapped (Callable[[p.TrainerProtocol[Input, Target, Output]], None]) – the function to be wrapped and called periodically.

  • parameters (dict[str, Any]) – metadata associated with the hook.

  • interval (int) – the frequency of calling the function.

  • start (int) – the epoch to start calling the function.

__call__(trainer: TrainerProtocol[Input, Target, Output]) → None

Call the wrapped callable only on scheduled epochs.

Parameters:

trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the wrapped function.

Return type:

None

class ChangeSchedulerOnPlateauCallback(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), cooldown: int = 0)

Bases: Generic[Output, Target]

Change the learning rate schedule when a metric has stopped improving.

monitor

monitor instance.

Type:

drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

cooldown

number of calls to skip after changing the schedule.

Type:

int

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • patience (int) – number of checks to wait before changing the schedule.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.

  • cooldown (int) – calls to skip after changing the schedule.

__call__(instance: TrainerProtocol[Any, Target, Output]) → None

Check for a plateau and change the learning rate schedule if needed.

Parameters:

instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to evaluate.

Return type:

None

abstractmethod get_scheduler(epoch: int, scheduler: SchedulerProtocol) → SchedulerProtocol

Modify the input scheduler.

Parameters:
  • epoch (int) – the current epoch.

  • scheduler (SchedulerProtocol) – the scheduler to modify.

Returns:

Modified scheduler.

Return type:

SchedulerProtocol
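The interplay of patience, cooldown, and get_scheduler can be pictured with the minimal sketch below. It is not the drytorch implementation: it assumes lower values are better, receives metric values directly instead of extracting them from a trainer, and treats calls made during cooldown as skipped entirely. The class name PlateauSketch and its changes counter are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class PlateauSketch:
    """Hypothetical control flow for a plateau-triggered schedule change.

    Assumes lower metric values are better and that calls made during
    cooldown are skipped entirely.
    """

    patience: int = 0
    cooldown: int = 0
    min_delta: float = 1e-8
    best: float = float('inf')
    wait: int = 0  # non-improving checks since the last improvement
    cooldown_left: int = 0
    changes: int = 0  # counts the points where get_scheduler would run

    def __call__(self, value: float) -> None:
        if self.cooldown_left > 0:
            # Still cooling down after the last schedule change.
            self.cooldown_left -= 1
            return
        if value < self.best - self.min_delta:
            # Improvement by more than min_delta resets the patience counter.
            self.best = value
            self.wait = 0
            return
        self.wait += 1
        if self.wait > self.patience:
            # Patience exhausted: the real callback would call get_scheduler.
            self.changes += 1
            self.wait = 0
            self.cooldown_left = self.cooldown


sketch = PlateauSketch(patience=0, cooldown=1)
for value in [1.0, 1.0, 1.0, 1.0]:
    sketch(value)
print(sketch.changes)  # 2
```

With cooldown=1 the third call is skipped, so the flat sequence triggers two changes instead of three.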

class EarlyStoppingCallback(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 10, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), start_from_epoch: int = 2)

Bases: Generic[Output, Target]

Implement early stopping logic for training models.

monitor

monitor instance.

Type:

drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

start_from_epoch

epoch to start monitoring from.

Type:

int

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • patience (int) – number of calls to wait before stopping.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.

  • start_from_epoch (int) – first epoch to start monitoring from.

__call__(instance: TrainerProtocol[Any, Target, Output]) → None

Evaluate whether training should be stopped early.

Parameters:

instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to evaluate.

Return type:

None
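A minimal sketch of the early stopping rule, under stated assumptions: lower is better, epochs are numbered from 1, and training stops once more than patience consecutive monitored epochs fail to improve by min_delta. The function early_stop_epoch is hypothetical and works on a plain list of metric values rather than on a trainer.

```python
from __future__ import annotations

from typing import Sequence


def early_stop_epoch(
    history: Sequence[float],
    patience: int = 10,
    min_delta: float = 1e-8,
    start_from_epoch: int = 2,
) -> int | None:
    """Return the epoch at which training would stop, or None.

    Hypothetical sketch: lower is better, epochs are numbered from 1.
    """
    best = float('inf')
    wait = 0
    for epoch, value in enumerate(history, start=1):
        if epoch < start_from_epoch:
            continue  # too early to monitor
        if value < best - min_delta:
            best, wait = value, 0
        else:
            wait += 1
            if wait > patience:
                return epoch
    return None


print(early_stop_epoch([5.0, 4.0, 3.0, 3.0, 3.0, 3.0], patience=2))  # 6
```

Skipping early epochs matters when the first measurements are unrepresentative: early_stop_epoch([1.0, 9.0, 9.0], patience=0) returns 3, whereas with start_from_epoch=1 the first value becomes the reference and it returns 2.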

class Hook(wrapped: Callable[[TrainerProtocol[Input, Target, Output]], None], base_hook_name: str | None = None, parameters: dict[str, Any] | None = None)

Bases: TrainerHook[Input, Target, Output]

Wrapper for callable taking a Trainer as input.

parameters

metadata associated with the hook.

Type:

dict[str, Any]

base_hook_name

name of the base hook for representation.

Type:

str

Initialize.

Parameters:
  • wrapped (Callable[[p.TrainerProtocol[Input, Target, Output]], None]) – the function to be wrapped and called with the trainer.

  • base_hook_name (str) – name of the base hook for representation. Defaults to the class name.

  • parameters (dict[str, Any]) – metadata associated with the hook.

__call__(trainer: TrainerProtocol[Input, Target, Output]) → None

Execute the call.

Parameters:

trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the wrapped function.

Return type:

None
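The wrapping behavior of Hook can be approximated as follows. SketchHook is a hypothetical stand-in, not the library class; in particular the string produced by __repr__ is an assumed format that only illustrates what base_hook_name and parameters are for.

```python
from __future__ import annotations

from typing import Any, Callable, Generic, TypeVar

T = TypeVar('T')  # stands in for TrainerProtocol[Input, Target, Output]


class SketchHook(Generic[T]):
    """Hypothetical stand-in for Hook: wraps a callable taking the trainer."""

    def __init__(
        self,
        wrapped: Callable[[T], None],
        base_hook_name: str | None = None,
        parameters: dict[str, Any] | None = None,
    ) -> None:
        self.wrapped = wrapped
        # Fall back to the class name, as documented for base_hook_name.
        self.base_hook_name = base_hook_name or type(self).__name__
        self.parameters = parameters or {}

    def __call__(self, trainer: T) -> None:
        # Forward the trainer to the wrapped function.
        self.wrapped(trainer)

    def __repr__(self) -> str:
        # Assumed format: the base name followed by the metadata.
        args = ', '.join(f'{k}={v!r}' for k, v in self.parameters.items())
        return f'{self.base_hook_name}({args})'


seen = []
hook = SketchHook(seen.append, base_hook_name='save', parameters={'every': 2})
hook('trainer')
print(seen, repr(hook))  # ['trainer'] save(every=2)
```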

class HookRegistry

Bases: Generic[_T_contra]

A registry for managing and executing hooks.

Each hook receives the same generic input object and can access it.

hooks

a list of registered hooks.

Type:

list[collections.abc.Callable[[drytorch.lib.hooks._T_contra], None]]

Initialize.

execute(input_object: _T_contra) → None

Execute the registered hooks in order of registration.

Parameters:

input_object (_T_contra) – the input to pass to each hook.

Return type:

None

register(hook: Callable[[_T_contra], None]) → None

Register a single hook.

Parameters:

hook (Callable[[_T_contra], None]) – the hook to register.

Return type:

None

register_all(hook_list: list[Callable[[_T_contra], None]]) → None

Register multiple hooks.

Parameters:

hook_list (list[Callable[[_T_contra], None]]) – the list of hooks to register.

Return type:

None
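A minimal re-implementation of the documented registry behavior (register, register_all, and execute in order of registration), assuming nothing beyond what this page states. The class name SketchRegistry is hypothetical.

```python
from __future__ import annotations

from typing import Callable, Generic, TypeVar

_T_contra = TypeVar('_T_contra', contravariant=True)


class SketchRegistry(Generic[_T_contra]):
    """Hypothetical registry mirroring the documented HookRegistry methods."""

    def __init__(self) -> None:
        self.hooks: list[Callable[[_T_contra], None]] = []

    def register(self, hook: Callable[[_T_contra], None]) -> None:
        self.hooks.append(hook)

    def register_all(self, hook_list: list[Callable[[_T_contra], None]]) -> None:
        for hook in hook_list:
            self.register(hook)

    def execute(self, input_object: _T_contra) -> None:
        # Hooks run in order of registration, all on the same input.
        for hook in self.hooks:
            hook(input_object)


calls = []
registry = SketchRegistry()
registry.register(lambda x: calls.append(('first', x)))
registry.register_all([lambda x: calls.append(('second', x))])
registry.execute(7)
print(calls)  # [('first', 7), ('second', 7)]
```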

class MetricExtractor(metric: ObjectiveProtocol[Any, Any] | str | None = None, monitor: MonitorProtocol | None = None)

Bases: object

Handle extraction of metrics from trainer/validation protocols.

This class is responsible for interfacing with trainer and validation protocols to extract metric values.

metric_spec

the metric specification (name or protocol instance).

Type:

drytorch.core.protocols.ObjectiveProtocol[Any, Any] | str | None

optional_monitor

evaluation protocol to monitor.

Type:

drytorch.core.protocols.MonitorProtocol | None

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Any, Any] | str | None) – name of the metric to monitor or metric calculator instance.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor.

property metric_name: str | None

Get the resolved metric name.

extract_metric_value(instance: TrainerProtocol[Input, Target, Output], tracker: MetricTracker[Output, Target]) → float

Extract and return the metric value from the instance.

Parameters:
  • instance (TrainerProtocol[Input, Target, Output]) – Trainer instance to extract from.

  • tracker (MetricTracker[Output, Target]) – the objectives.MetricTracker whose metric name may be updated.

Returns:

The extracted metric value.

Raises:

MetricNotFoundError – if the specified metric is not found.

Return type:

float

get_metric_best_is() → Literal['auto', 'higher', 'lower'] | None

Get the best_is preference from the metric if available.

Return type:

Literal['auto', 'higher', 'lower'] | None

class MetricMonitor(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1))

Bases: Generic[Output, Target]

Handle metric monitoring and alerts when performance stops improving.

metric_tracker

handles metric value tracking and improvement detection.

Type:

drytorch.lib.objectives.MetricTracker[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

extractor

handles metric extraction from protocols.

Type:

drytorch.lib.hooks.MetricExtractor

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of the metric to monitor or metric calculator instance.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • patience (int) – number of checks to wait before triggering the callback.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values.
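Any callable from a sequence of floats to a float can serve as filter_fn. The default operator.itemgetter(-1) picks the latest value; a moving average, as in the hypothetical moving_average below, smooths out single-epoch spikes before the value is compared with the best one.

```python
import operator
from typing import Sequence

# The default used throughout this module: keep only the most recent value.
last_value = operator.itemgetter(-1)


def moving_average(values: Sequence[float], window: int = 3) -> float:
    """Hypothetical filter_fn: mean of the last `window` recorded values."""
    recent = values[-window:]
    return sum(recent) / len(recent)


history = [4.0, 1.0, 2.0, 6.0]
print(last_value(history))      # 6.0
print(moving_average(history))  # (1.0 + 2.0 + 6.0) / 3 = 3.0
```

A different window can be fixed with functools.partial, e.g. functools.partial(moving_average, window=5).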

property metric_name: str | None

Get the metric name being monitored.

property best_value: float

Get the best result observed so far.

property filtered_value: float

Get the current filtered value.

property history: list[float]

Get the metric history.

is_better(value: float, reference: float) → bool

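Check whether a value is better than a reference value.

Parameters:
  • value (float) – the value to evaluate.

  • reference (float) – the value to compare against.

Return type: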

bool

is_improving() → bool

Determine if the model performance is improving.

Return type:

bool

is_patient() → bool

Check whether to be patient.

Return type:

bool

record_metric_value(instance: TrainerProtocol[Any, Target, Output]) → None

Register a new metric value from a monitored evaluation.

Parameters:

instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to extract metric from.

Raises:

MetricNotFoundError – if the specified metric is not found.

Return type:

None
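The comparison and patience logic behind is_better, is_improving, and is_patient might look like the sketch below. It is hypothetical: MonitorSketch records raw values through a record method instead of reading them from a trainer, omits the 'auto' resolution of best_is and the filter_fn step, and assumes that up to patience non-improving checks are tolerated.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class MonitorSketch:
    """Hypothetical sketch of the comparison and patience logic."""

    min_delta: float = 1e-8
    patience: int = 0
    best_is: str = 'lower'  # resolving 'auto' is omitted here
    history: list[float] = field(default_factory=list)
    best_value: float | None = None
    wait: int = 0  # checks since the last improvement

    def is_better(self, value: float, reference: float) -> bool:
        # An improvement must beat the reference by more than min_delta.
        if self.best_is == 'higher':
            return value > reference + self.min_delta
        return value < reference - self.min_delta

    def record(self, value: float) -> None:
        self.history.append(value)
        if self.best_value is None or self.is_better(value, self.best_value):
            self.best_value = value
            self.wait = 0
        else:
            self.wait += 1

    def is_improving(self) -> bool:
        return self.wait == 0

    def is_patient(self) -> bool:
        # Assumed convention: tolerate up to `patience` non-improving checks.
        return self.wait <= self.patience


monitor = MonitorSketch(patience=1)
for value in [1.0, 1.0, 1.0]:
    monitor.record(value)
print(monitor.is_improving(), monitor.is_patient())  # False False
```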

class PruneCallback(thresholds: Mapping[int, float | None], metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1))

Bases: Generic[Output, Target]

Implement pruning logic for training models.

monitor

monitor instance.

Type:

drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

thresholds

dictionary mapping epochs to pruning thresholds.

Type:

collections.abc.Mapping[int, float | None]

Initialize.

Parameters:
  • thresholds (Mapping[int, float | None]) – dictionary mapping epochs to pruning thresholds.

  • metric (str | p.ObjectiveProtocol[Output, Target] | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate the intermediate metric values. Default gets the last value.

__call__(instance: TrainerProtocol[Any, Target, Output]) → None

Evaluate whether training should be stopped early.

Parameters:

instance (TrainerProtocol[Any, Target, Output]) – trainer instance to evaluate.

Return type:

None
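One plausible reading of the thresholds mapping is sketched below: at an epoch that has a threshold, the run is pruned if the monitored value is worse than it, while epochs without an entry, or mapped to None, never prune. This interpretation, the should_prune name, and the omission of min_delta are assumptions, not documented behavior.

```python
from __future__ import annotations

from typing import Mapping


def should_prune(
    epoch: int,
    value: float,
    thresholds: Mapping[int, float | None],
    best_is: str = 'lower',
) -> bool:
    """Hypothetical sketch: prune when the value misses the epoch's threshold."""
    threshold = thresholds.get(epoch)
    if threshold is None:
        return False  # no threshold (or an explicit None) at this epoch
    if best_is == 'higher':
        return value < threshold
    return value > threshold


thresholds = {5: 0.8, 10: None, 20: 0.5}
print(should_prune(5, 0.9, thresholds))  # True: loss 0.9 misses 0.8
print(should_prune(5, 0.7, thresholds))  # False
```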

class ReduceLROnPlateau(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), factor: float = 0.1, cooldown: int = 0)

Bases: ChangeSchedulerOnPlateauCallback[Output, Target]

Reduce the learning rate when a metric has stopped improving.

monitor

monitor instance.

Type:

drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

cooldown

number of calls to skip after changing the schedule.

Type:

int

factor

factor by which to reduce the learning rate.

Type:

float

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • patience (int) – number of checks to wait before changing the schedule.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.

  • cooldown (int) – calls to skip after changing the schedule.

  • factor (float) – factor by which to reduce the learning rate.

get_scheduler(epoch: int, scheduler: SchedulerProtocol) → SchedulerProtocol

Modify the input scheduler to scale down the learning rate.

Parameters:
  • epoch (int) – the current epoch.

  • scheduler (SchedulerProtocol) – the scheduler to modify.

Returns:

Modified scheduler.

Return type:

SchedulerProtocol
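Scaling down a schedule by factor can be expressed as wrapping one function in another. The sketch assumes a hypothetical schedule shape, a callable (base_lr, epoch) -> lr, because SchedulerProtocol is not described on this page; it follows the common convention that the new learning rate is the old one multiplied by factor.

```python
from typing import Callable

# Hypothetical schedule shape: (base_lr, epoch) -> lr. The real
# SchedulerProtocol is not specified on this page and may differ.
Schedule = Callable[[float, int], float]


def constant(base_lr: float, epoch: int) -> float:
    return base_lr


def scale(schedule: Schedule, factor: float) -> Schedule:
    """Return a schedule that multiplies every learning rate by `factor`."""

    def scaled(base_lr: float, epoch: int) -> float:
        return factor * schedule(base_lr, epoch)

    return scaled


reduced = scale(constant, 0.5)
print(reduced(0.25, 3))  # 0.125
```

Each further plateau wraps the schedule again, so reductions compound: scale(scale(constant, 0.5), 0.5) yields a quarter of the base rate.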

class RestartScheduleOnPlateau(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), cooldown: int = 0)

Bases: ChangeSchedulerOnPlateauCallback[Output, Target]

Restart the learning rate schedule when a metric has stopped improving.

monitor

monitor instance.

Type:

drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

cooldown

number of calls to skip after changing the schedule.

Type:

int

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • patience (int) – number of checks to wait before changing the schedule.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.

  • cooldown (int) – calls to skip after changing the schedule.

get_scheduler(epoch: int, scheduler: SchedulerProtocol) → SchedulerProtocol

Treat the training so far as a warm-up and restart the schedule.

Parameters:
  • epoch (int) – the current epoch.

  • scheduler (SchedulerProtocol) – the scheduler to modify.

Returns:

Modified scheduler.

Return type:

SchedulerProtocol
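Restarting a schedule can be sketched as shifting the epoch counter so that the schedule replays from its beginning at the plateau epoch. The (base_lr, epoch) -> lr shape and the names linear_decay and restart_at are hypothetical and serve only as an illustration.

```python
from typing import Callable

# Hypothetical schedule shape: (base_lr, epoch) -> lr.
Schedule = Callable[[float, int], float]


def linear_decay(base_lr: float, epoch: int) -> float:
    # Decays linearly to zero over 10 epochs.
    return base_lr * max(0.0, 1 - epoch / 10)


def restart_at(schedule: Schedule, restart_epoch: int) -> Schedule:
    """Replay `schedule` from epoch 0 starting at `restart_epoch`."""

    def restarted(base_lr: float, epoch: int) -> float:
        return schedule(base_lr, max(0, epoch - restart_epoch))

    return restarted


restarted = restart_at(linear_decay, 5)
print(linear_decay(1.0, 5), restarted(1.0, 5))  # 0.5 1.0
```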

class StaticHook(wrapped: Callable[[], None], base_hook_name: str | None = None, parameters: dict[str, Any] | None = None)

Bases: TrainerHook[Any, Any, Any]

Ignore arguments and execute a wrapped function.

parameters

metadata associated with the hook.

Type:

dict[str, Any]

base_hook_name

name of the base hook for representation.

Type:

str

Initialize.

Parameters:
  • wrapped (Callable[[], None]) – the function to be wrapped and called statically.

  • base_hook_name (str) – name of the base hook for representation. Defaults to the class name.

  • parameters (dict[str, Any]) – metadata associated with the hook.

__call__(trainer: TrainerProtocol[Input, Target, Output]) → None

Execute the call.

Parameters:

trainer (TrainerProtocol[Input, Target, Output]) – the trainer instance, which is ignored.

Return type:

None
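The hypothetical stand-in StaticHookSketch below illustrates what "static" means here: the hook accepts the trainer so that it can be registered like any other hook, but it calls the wrapped zero-argument function without it.

```python
from typing import Any, Callable


class StaticHookSketch:
    """Hypothetical stand-in: call a zero-argument function, ignore the trainer."""

    def __init__(self, wrapped: Callable[[], None]) -> None:
        self.wrapped = wrapped

    def __call__(self, trainer: Any) -> None:
        del trainer  # intentionally unused
        self.wrapped()


ticks = []
tick = StaticHookSketch(lambda: ticks.append('tick'))
tick(object())  # any trainer-like object is accepted and ignored
print(ticks)  # ['tick']
```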

class TrainerHook

Bases: Generic[Input, Target, Output]

Compose control flow on side effects while keeping track of parameters.

Attributes:
  • parameters – metadata associated with the hook.

  • base_hook_name – name of the base hook for representation.

abstractmethod __call__(trainer: TrainerProtocol[Input, Target, Output]) → None

Execute the call.

Parameters:

trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the wrapped function.

Return type:

None

bind(f: Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], Hook], /) → Hook[Input, Target, Output]

Transform the hook with the given function.

Parameters:

f (Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], Hook]) – a function specifying the transformation.

Returns:

the transformed Hook.

Return type:

Hook[Input, Target, Output]
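bind passes the hook to a transformation and returns whatever hook the transformation builds. The sketch below is a hypothetical stand-in: BindableHook and the only_if transformer are illustrative names, and the trainer is replaced by an integer.

```python
from __future__ import annotations

from typing import Any, Callable

TrainerFn = Callable[[Any], None]


class BindableHook:
    """Hypothetical stand-in for a hook that supports bind()."""

    def __init__(self, wrapped: TrainerFn) -> None:
        self.wrapped = wrapped

    def __call__(self, trainer: Any) -> None:
        self.wrapped(trainer)

    def bind(self, f: Callable[[TrainerFn], BindableHook], /) -> BindableHook:
        # Hand this hook, as a plain callable, to f and return the new hook.
        return f(self)


def only_if(predicate: Callable[[Any], bool]) -> Callable[[TrainerFn], BindableHook]:
    """Hypothetical transformer: run the hook only when predicate(trainer) holds."""

    def transform(hook: TrainerFn) -> BindableHook:
        def guarded(trainer: Any) -> None:
            if predicate(trainer):
                hook(trainer)

        return BindableHook(guarded)

    return transform


seen = []
hook = BindableHook(seen.append).bind(only_if(lambda t: t % 2 == 0))
for step in range(5):
    hook(step)
print(seen)  # [0, 2, 4]
```

Because each bind returns another hook, transformations chain: hook.bind(a).bind(b) applies b on top of a.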

call_every(interval: int, start: int = 0) → Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], Hook[Input, Target, Output]]

Create a transformer for periodic hook execution.

Parameters:
  • interval (int) – the frequency of calling the hook.

  • start (int) – the epoch to start calling the hook.

Returns:

A transformer that adds periodic logic to a callable.

Return type:

Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], Hook[Input, Target, Output]]
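A transformer produced by call_every can be sketched as a closure that gates the wrapped callable on the epoch. The FakeTrainer class, its epoch attribute, and the exact rule (call at start, start + interval, start + 2 * interval, and so on) are assumptions made for illustration.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


@dataclass
class FakeTrainer:
    """Hypothetical trainer exposing only the current epoch."""

    epoch: int


TrainerFn = Callable[[FakeTrainer], None]


def call_every_sketch(interval: int, start: int = 0) -> Callable[[TrainerFn], TrainerFn]:
    """Return a transformer that runs a hook only on scheduled epochs."""

    def transform(fn: TrainerFn) -> TrainerFn:
        def periodic(trainer: FakeTrainer) -> None:
            # Assumed rule: due at start, start + interval, start + 2 * interval, ...
            due = trainer.epoch >= start and (trainer.epoch - start) % interval == 0
            if due:
                fn(trainer)

        return periodic

    return transform


epochs_seen = []
log_epoch = call_every_sketch(3, start=2)(lambda t: epochs_seen.append(t.epoch))
for epoch in range(1, 12):
    log_epoch(FakeTrainer(epoch))
print(epochs_seen)  # [2, 5, 8, 11]
```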

static_hook_class(cls: Callable[[_P], Callable[[], None]]) → Callable[[_P], StaticHook]

Class decorator to wrap a callable class into a static hook type.

Parameters:

cls (Callable[[~_P], Callable[[], None]]) – a callable class whose instances take no arguments and return None.

Returns:

A class that can be instantiated with the same arguments to obtain a static hook.

Return type:

Callable[[~_P], StaticHook]
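A decorator with this contract can be sketched as a factory that builds the decorated class with the caller's arguments and returns a hook that ignores the trainer. static_hook_class_sketch and AppendTo are hypothetical, and the sketch returns a plain function instead of a StaticHook instance.

```python
from typing import Any, Callable


def static_hook_class_sketch(
    cls: Callable[..., Callable[[], None]],
) -> Callable[..., Callable[[Any], None]]:
    """Hypothetical sketch: make instances of `cls` usable as trainer hooks."""

    def factory(*args: Any, **kwargs: Any) -> Callable[[Any], None]:
        instance = cls(*args, **kwargs)  # same constructor arguments as cls

        def hook(trainer: Any) -> None:
            del trainer  # static: the trainer is ignored
            instance()

        return hook

    return factory


@static_hook_class_sketch
class AppendTo:
    def __init__(self, sink: list, item: str) -> None:
        self.sink = sink
        self.item = item

    def __call__(self) -> None:
        self.sink.append(self.item)


sink = []
hook = AppendTo(sink, 'epoch done')  # instantiated as usual
hook('trainer')
print(sink)  # ['epoch done']
```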