drytorch.lib.hooks

Module containing registry, callbacks, and hooks for a Trainer.

Functions

call_every(interval[, start])

Create a decorator for periodic hook execution.

static_hook_class(cls)

Class decorator to wrap a callable class into a static hook type.

Classes

CallEvery(wrapped, interval, start)

Call a function at specified intervals.

ChangeSchedulerOnPlateauCallback([metric, ...])

Change the learning rate schedule when a metric has stopped improving.

EarlyStoppingCallback([metric, monitor, ...])

Implement early stopping logic for training models.

Hook(wrapped)

Wrapper for callable taking a Trainer as input.

HookRegistry()

A registry for managing and executing hooks.

MetricExtractor([metric, monitor])

Handle extraction of metrics from trainer/validation protocols.

MetricMonitor([metric, monitor, min_delta, ...])

Handle metric monitoring and alerts when performance stops improving.

OptionalCallable(wrapped)

Abstract class for callables that execute based on custom conditions.

PruneCallback(thresholds[, metric, monitor, ...])

Implement pruning logic for training models.

ReduceLROnPlateau([metric, monitor, ...])

Reduce the learning rate when a metric has stopped improving.

RestartScheduleOnPlateau([metric, monitor, ...])

Restart the schedule after the monitored metric plateaus.

StaticHook(wrapped)

Ignore arguments and execute a wrapped function.

TrainerHook()

Callable supporting bind operations.

class ChangeSchedulerOnPlateauCallback(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), cooldown: int = 0)

Bases: Generic[Output, Target]

Change the learning rate schedule when a metric has stopped improving.

monitor

monitor instance.

Type:

drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

cooldown

number of calls to skip after changing the schedule.

Type:

int

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • patience (int) – number of checks to wait before changing the schedule.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.

  • cooldown (int) – number of calls to skip after changing the schedule.

__call__(instance: TrainerProtocol[Any, Target, Output]) -> None

Check whether the metric has plateaued and change the learning rate schedule if needed.

Parameters:

instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to evaluate.

Return type:

None

abstractmethod get_scheduler(epoch: int, scheduler: SchedulerProtocol) -> SchedulerProtocol

Modify the input scheduler.

Parameters:
  • epoch (int) – the current epoch.

  • scheduler (SchedulerProtocol) – the scheduler to modify.

Returns:

Modified scheduler.

Return type:

SchedulerProtocol
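A minimal sketch of how patience and cooldown can interact in a plateau callback. It is self-contained and hypothetical: the class PlateauScheduleChanger, its step method and the on_plateau callable (standing in for get_scheduler) are illustrative names, not drytorch API. The sketch assumes lower metric values are better and that a change fires once more than patience consecutive checks fail to improve.

```python
from collections.abc import Callable


class PlateauScheduleChanger:
    """Hypothetical plateau detector with patience and cooldown.

    Assumes lower metric values are better.
    """

    def __init__(
        self,
        on_plateau: Callable[[int], None],
        min_delta: float = 1e-8,
        patience: int = 0,
        cooldown: int = 0,
    ) -> None:
        self.on_plateau = on_plateau  # stands in for get_scheduler
        self.min_delta = min_delta
        self.patience = patience
        self.cooldown = cooldown
        self.best = float('inf')
        self.bad_checks = 0
        self.cooldown_left = 0

    def step(self, epoch: int, value: float) -> None:
        # Skip this call entirely while cooling down after a change.
        if self.cooldown_left > 0:
            self.cooldown_left -= 1
            return
        # An improvement must beat the best value by more than min_delta.
        if value < self.best - self.min_delta:
            self.best = value
            self.bad_checks = 0
            return
        self.bad_checks += 1
        # Change the schedule once more than `patience` checks have failed.
        if self.bad_checks > self.patience:
            self.on_plateau(epoch)
            self.bad_checks = 0
            self.cooldown_left = self.cooldown


changes: list[int] = []
changer = PlateauScheduleChanger(changes.append, patience=1, cooldown=1)
for epoch, loss in enumerate([1.0, 0.9, 0.9, 0.9, 0.9, 0.9], start=1):
    changer.step(epoch, loss)
# changes == [4]: epochs 3 and 4 fail to improve, epoch 5 is skipped.
```

Resetting the failure counter after each change means a second change needs a fresh run of failed checks, which is one common way to avoid compounding changes on every call.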

class EarlyStoppingCallback(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 10, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), start_from_epoch: int = 2)

Bases: Generic[Output, Target]

Implement early stopping logic for training models.

monitor

monitor instance.

Type:

drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

start_from_epoch

first epoch from which the metric is monitored.

Type:

int

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • patience (int) – number of calls to wait before stopping.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.

  • start_from_epoch (int) – first epoch to start monitoring from.

__call__(instance: TrainerProtocol[Any, Target, Output]) -> None

Evaluate whether training should be stopped early.

Parameters:

instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to evaluate.

Return type:

None
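The early-stopping rule described above can be sketched as follows. This is a hypothetical, framework-free illustration: EarlyStopper and should_stop are invented names, higher values are assumed to be better, and a stop is requested once more than patience checks (counted from start_from_epoch on) fail to improve by min_delta.

```python
class EarlyStopper:
    """Hypothetical early-stopping sketch; assumes higher values are better."""

    def __init__(
        self,
        min_delta: float = 1e-8,
        patience: int = 10,
        start_from_epoch: int = 2,
    ) -> None:
        self.min_delta = min_delta
        self.patience = patience
        self.start_from_epoch = start_from_epoch
        self.best = float('-inf')
        self.bad_checks = 0

    def should_stop(self, epoch: int, value: float) -> bool:
        # Ignore epochs before monitoring begins.
        if epoch < self.start_from_epoch:
            return False
        if value > self.best + self.min_delta:
            self.best = value
            self.bad_checks = 0
            return False
        self.bad_checks += 1
        return self.bad_checks > self.patience


stopper = EarlyStopper(patience=2, start_from_epoch=2)
accuracies = [0.1, 0.5, 0.6, 0.6, 0.6, 0.6]
stop_flags = [
    stopper.should_stop(e, a) for e, a in enumerate(accuracies, start=1)
]
# Only the sixth epoch requests a stop (three checks without improvement).
```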

class Hook(wrapped: Callable[[TrainerProtocol[Input, Target, Output]], None])

Bases: TrainerHook[Input, Target, Output]

Wrapper for callable taking a Trainer as input.

wrapped

the function to be called.

Type:

collections.abc.Callable[[drytorch.core.protocols.TrainerProtocol[drytorch.lib.hooks.Input, drytorch.lib.hooks.Target, drytorch.lib.hooks.Output]], None]

Initialize.

Parameters:

wrapped (Callable[[TrainerProtocol[Input, Target, Output]], None]) – the function to be called.

__call__(trainer: TrainerProtocol[Input, Target, Output]) -> None

Execute the call.

Parameters:

trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the wrapped function.

Return type:

None
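The wrapper pattern is small enough to show in full. In this hypothetical sketch, SimpleHook mirrors the documented behavior (calling the hook forwards the trainer to the wrapped function), and FakeTrainer is an invented stand-in for a TrainerProtocol implementation, not a drytorch class.

```python
from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class FakeTrainer:
    """Hypothetical stand-in for a TrainerProtocol implementation."""

    epoch: int = 0


class SimpleHook:
    """Hypothetical sketch: forward the trainer to the wrapped function."""

    def __init__(self, wrapped: Callable[[FakeTrainer], None]) -> None:
        self.wrapped = wrapped

    def __call__(self, trainer: FakeTrainer) -> None:
        self.wrapped(trainer)


seen: list[int] = []
hook = SimpleHook(lambda trainer: seen.append(trainer.epoch))
hook(FakeTrainer(epoch=3))
```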

class MetricExtractor(metric: ObjectiveProtocol[Any, Any] | str | None = None, monitor: MonitorProtocol | None = None)

Bases: object

Handle extraction of metrics from trainer/validation protocols.

This class is responsible for interfacing with trainer and validation protocols to extract metric values.

metric_spec

the metric specification (name or protocol instance).

Type:

drytorch.core.protocols.ObjectiveProtocol[Any, Any] | str | None

optional_monitor

evaluation protocol to monitor.

Type:

drytorch.core.protocols.MonitorProtocol | None

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Any, Any] | str | None) – name of the metric to monitor or metric calculator instance.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor.

property metric_name: str | None

Get the resolved metric name.

extract_metric_value(instance: TrainerProtocol[Input, Target, Output], tracker: MetricTracker[Output, Target]) -> float

Extract and return the metric value from the instance.

Parameters:
  • instance (TrainerProtocol[Input, Target, Output]) – Trainer instance to extract from.

  • tracker (MetricTracker[Output, Target]) – the objectives.MetricTracker whose metric name may be updated.

Returns:

The extracted metric value.

Raises:

MetricNotFoundError – if the specified metric is not found.

Return type:

float

get_metric_best_is() -> Literal['auto', 'higher', 'lower'] | None

Get the best_is preference from the metric if available.

Return type:

Literal['auto', 'higher', 'lower'] | None
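The resolution rules stated for the callbacks (monitor validation if available, otherwise the trainer; default to the first metric found; fail if a named metric is missing) can be sketched with plain dictionaries. This is a hypothetical illustration: extract_metric and the local MetricNotFoundError stand-in are invented here, and the real class reads metrics from the trainer and validation protocols instead.

```python
from __future__ import annotations

from collections.abc import Mapping


class MetricNotFoundError(KeyError):
    """Hypothetical stand-in for the MetricNotFoundError documented above."""


def extract_metric(
    train_metrics: Mapping[str, float],
    val_metrics: Mapping[str, float] | None = None,
    name: str | None = None,
) -> tuple[str, float]:
    """Return (name, value) of the monitored metric."""
    # Prefer validation metrics when available, as the monitor default does.
    metrics = val_metrics if val_metrics is not None else train_metrics
    if name is None:
        if not metrics:
            raise MetricNotFoundError('no metrics available')
        name = next(iter(metrics))  # first metric, in insertion order
    if name not in metrics:
        raise MetricNotFoundError(name)
    return name, metrics[name]


train = {'loss': 0.5, 'accuracy': 0.8}
val = {'loss': 0.6, 'accuracy': 0.7}
```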

class MetricMonitor(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1))

Bases: Generic[Output, Target]

Handle metric monitoring and alerts when performance stops improving.

metric_tracker

handles metric value tracking and improvement detection.

Type:

drytorch.lib.objectives.MetricTracker[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

extractor

handles metric extraction from protocols.

Type:

drytorch.lib.hooks.MetricExtractor

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of the metric to monitor or metric calculator instance.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • patience (int) – number of checks to wait before triggering the callback.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values.

property metric_name: str | None

Get the metric name being monitored.

property best_value: float

Get the best result observed so far.

property filtered_value: float

Get the current filtered value.

property history: list[float]

Get the metric history.

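is_better(value: float, reference: float) -> bool

Check whether a value improves on a reference value.

Parameters:
  • value (float) – the value to evaluate.

  • reference (float) – the value to compare against.

Return type: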

bool

is_improving() -> bool

Determine if the model performance is improving.

Return type:

bool

is_patient() -> bool

Check whether to be patient.

Return type:

bool

record_metric_value(instance: TrainerProtocol[Any, Target, Output]) -> None

Register a new metric value from a monitored evaluation.

Parameters:

instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to extract metric from.

Raises:

MetricNotFoundError – if the specified metric is not found.

Return type:

None
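How min_delta, best_is, filter_fn and patience can combine in a monitor is sketched below. TinyMonitor is a hypothetical class, not drytorch's MetricMonitor: it omits the 'auto' mode (whose detection rule is not specified here), records raw values directly instead of extracting them from a trainer, and treats is_patient as "failed checks have not exceeded patience".

```python
from __future__ import annotations

import operator
from collections.abc import Callable, Sequence
from typing import Literal


class TinyMonitor:
    """Hypothetical sketch of a metric monitor (no 'auto' mode)."""

    def __init__(
        self,
        best_is: Literal['higher', 'lower'] = 'lower',
        min_delta: float = 1e-8,
        patience: int = 0,
        filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1),
    ) -> None:
        self.best_is = best_is
        self.min_delta = min_delta
        self.patience = patience
        self.filter_fn = filter_fn
        self.history: list[float] = []
        self.best_value: float | None = None
        self.bad_checks = 0

    def is_better(self, value: float, reference: float) -> bool:
        # An improvement must beat the reference by more than min_delta.
        if self.best_is == 'higher':
            return value > reference + self.min_delta
        return value < reference - self.min_delta

    def record(self, value: float) -> None:
        self.history.append(value)
        # Aggregate recent values; the default keeps only the last one.
        filtered = self.filter_fn(self.history)
        if self.best_value is None or self.is_better(filtered, self.best_value):
            self.best_value = filtered
            self.bad_checks = 0
        else:
            self.bad_checks += 1

    def is_patient(self) -> bool:
        # Keep waiting while the failed checks do not exceed patience.
        return self.bad_checks <= self.patience


monitor = TinyMonitor(best_is='lower', patience=1)
for loss in [1.0, 0.8, 0.85, 0.9]:
    monitor.record(loss)
# best_value is 0.8; two failed checks exceed patience=1.
```

Passing a smoothing function such as lambda h: sum(h[-2:]) / len(h[-2:]) as filter_fn makes the monitor react to a short moving average rather than to single noisy values.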

class OptionalCallable(wrapped: Callable[[TrainerProtocol[Input, Target, Output]], None])

Bases: Hook[Input, Target, Output]

Abstract class for callables that execute based on custom conditions.

Initialize.

Parameters:

wrapped (Callable[[TrainerProtocol[Input, Target, Output]], None]) – the function to be conditionally called.

__call__(trainer: TrainerProtocol[Input, Target, Output]) -> None

Execute the call.

Parameters:

trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the wrapped function.

Return type:

None
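The conditional-call pattern can be sketched with an abstract predicate. This is a hypothetical illustration: ConditionalHook, its _should_call method and the AfterEpoch subclass are invented names, and SimpleNamespace stands in for a trainer exposing an epoch attribute.

```python
from abc import ABC, abstractmethod
from collections.abc import Callable
from types import SimpleNamespace
from typing import Any


class ConditionalHook(ABC):
    """Hypothetical sketch: call the wrapped function only if a condition holds."""

    def __init__(self, wrapped: Callable[[Any], None]) -> None:
        self.wrapped = wrapped

    def __call__(self, trainer: Any) -> None:
        if self._should_call(trainer):
            self.wrapped(trainer)

    @abstractmethod
    def _should_call(self, trainer: Any) -> bool:
        """Decide whether to call the wrapped function."""


class AfterEpoch(ConditionalHook):
    """Example condition: only call once the trainer passes a given epoch."""

    def __init__(self, wrapped: Callable[[Any], None], epoch: int) -> None:
        super().__init__(wrapped)
        self.epoch = epoch

    def _should_call(self, trainer: Any) -> bool:
        return trainer.epoch > self.epoch


calls: list[int] = []
hook = AfterEpoch(lambda t: calls.append(t.epoch), epoch=2)
for e in range(1, 5):
    hook(SimpleNamespace(epoch=e))  # stand-in trainer with an epoch
```

Keeping the condition in a separate abstract method lets subclasses change when the function runs without touching how it is called.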

class PruneCallback(thresholds: Mapping[int, float | None], metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1))

Bases: Generic[Output, Target]

Implement pruning logic for training models.

monitor

monitor instance.

Type:

drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

thresholds

dictionary mapping epochs to pruning thresholds.

Type:

collections.abc.Mapping[int, float | None]

Initialize.

Parameters:
  • thresholds (Mapping[int, float | None]) – dictionary mapping epochs to pruning thresholds.

  • metric (str | p.ObjectiveProtocol[Output, Target] | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate the intermediate result values. Default gets the last value.

__call__(instance: TrainerProtocol[Any, Target, Output]) -> None

Evaluate whether training should be stopped early.

Parameters:

instance (TrainerProtocol[Any, Target, Output]) – trainer instance to evaluate.

Return type:

None
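A threshold-based pruning check along the lines described above might look like this. The sketch is hypothetical: should_prune is an invented helper, it assumes a None threshold (or a missing epoch) means "no check at this epoch", and it compares a single monitored value against the threshold while ignoring min_delta.

```python
from __future__ import annotations

from collections.abc import Mapping


def should_prune(
    epoch: int,
    value: float,
    thresholds: Mapping[int, float | None],
    best_is: str = 'higher',
) -> bool:
    """Hypothetical sketch: prune if the value misses the epoch's threshold."""
    threshold = thresholds.get(epoch)
    if threshold is None:  # no threshold registered for this epoch
        return False
    if best_is == 'higher':
        return value < threshold
    return value > threshold


thresholds = {5: 0.5, 10: 0.8, 15: None}
```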

class ReduceLROnPlateau(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), factor: float = 0.1, cooldown: int = 0)

Bases: ChangeSchedulerOnPlateauCallback[Output, Target]

Reduce the learning rate when a metric has stopped improving.

monitor

monitor instance.

Type:

drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]

cooldown

number of calls to skip after changing the schedule.

Type:

int

factor

factor by which to reduce the learning rate.

Type:

float

Initialize.

Parameters:
  • metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.

  • monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • patience (int) – number of checks to wait before changing the schedule.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.

  • cooldown (int) – number of calls to skip after changing the schedule.

  • factor (float) – factor by which to reduce the learning rate.

get_scheduler(epoch: int, scheduler: SchedulerProtocol) -> SchedulerProtocol

Modify the input scheduler to scale down the learning rate by factor.

Parameters:
  • epoch (int) – the current epoch.

  • scheduler (SchedulerProtocol) – the scheduler to modify.

Returns:

Modified scheduler.

Return type:

SchedulerProtocol
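Scaling a schedule by factor can be expressed as wrapping one schedule function in another. This sketch assumes, hypothetically, that a schedule is a plain callable mapping (base_lr, epoch) to a learning rate; Schedule, scale_schedule and constant are invented names rather than drytorch's SchedulerProtocol.

```python
from collections.abc import Callable

# Hypothetical schedule type: maps (base_lr, epoch) to a learning rate.
Schedule = Callable[[float, int], float]


def scale_schedule(schedule: Schedule, factor: float) -> Schedule:
    """Return a schedule whose learning rate is `factor` times the original."""

    def scaled(base_lr: float, epoch: int) -> float:
        return factor * schedule(base_lr, epoch)

    return scaled


def constant(base_lr: float, epoch: int) -> float:
    """Keep the base learning rate at every epoch."""
    return base_lr


reduced = scale_schedule(constant, factor=0.1)
twice_reduced = scale_schedule(reduced, factor=0.1)  # a second plateau compounds
```

Wrapping instead of mutating leaves the original schedule intact, so each plateau simply adds another multiplicative layer.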

class TrainerHook

Bases: Generic[Input, Target, Output]

Callable supporting bind operations.

abstractmethod __call__(trainer: TrainerProtocol[Input, Target, Output]) -> None

Execute the call.

Parameters:

trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the wrapped function.

Return type:

None

bind(f: Callable[[TrainerHook[Input, Target, Output]], TrainerHook[Input, Target, Output]], /) -> TrainerHook[Input, Target, Output]

Allow transformation of the Hook.

Parameters:

f (Callable[[TrainerHook[Input, Target, Output]], TrainerHook[Input, Target, Output]]) – a function specifying the transformation.

Returns:

the transformed Hook.

Return type:

TrainerHook[Input, Target, Output]
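The bind operation takes a function from hook to hook and applies it to the current hook. In this hypothetical sketch, BindableHook and the every_other transformation are invented for illustration, and SimpleNamespace stands in for a trainer with an epoch attribute.

```python
from __future__ import annotations

from collections.abc import Callable
from types import SimpleNamespace
from typing import Any


class BindableHook:
    """Hypothetical sketch of a hook that supports bind()."""

    def __init__(self, fn: Callable[[Any], None]) -> None:
        self.fn = fn

    def __call__(self, trainer: Any) -> None:
        self.fn(trainer)

    def bind(self, f: Callable[[BindableHook], BindableHook], /) -> BindableHook:
        # Delegate to f, which receives this hook and returns a new hook.
        return f(self)


def every_other(hook: BindableHook) -> BindableHook:
    """Example transformation: only forward calls on even epochs."""
    return BindableHook(lambda t: hook(t) if t.epoch % 2 == 0 else None)


epochs: list[int] = []
hook = BindableHook(lambda t: epochs.append(t.epoch)).bind(every_other)
for e in range(1, 6):
    hook(SimpleNamespace(epoch=e))  # stand-in trainer with an epoch
```

Because bind returns a hook, transformations can be chained, e.g. hook.bind(f).bind(g).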

call_every(interval: int, start: int = 0) -> Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], CallEvery[Input, Target, Output]]

Create a decorator for periodic hook execution.

Parameters:
  • interval (int) – the number of epochs between calls.

  • start (int) – the epoch to start calling the hook.

Returns:

A decorator that wraps a function in a CallEvery hook.

Return type:

Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], CallEvery[Input, Target, Output]]
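A periodic-call decorator in the spirit of call_every can be sketched as below. The exact rule used by CallEvery is not stated here, so this hypothetical every decorator assumes the hook fires at epoch start and then every interval epochs after it, reading the epoch from a trainer.epoch attribute; SimpleNamespace is a stand-in trainer.

```python
from collections.abc import Callable
from types import SimpleNamespace
from typing import Any

HookFn = Callable[[Any], None]


def every(interval: int, start: int = 0) -> Callable[[HookFn], HookFn]:
    """Hypothetical sketch of a periodic-call decorator keyed on trainer.epoch."""

    def decorator(fn: HookFn) -> HookFn:
        def hook(trainer: Any) -> None:
            epoch = trainer.epoch
            # Call from `start` on, then every `interval` epochs.
            if epoch >= start and (epoch - start) % interval == 0:
                fn(trainer)

        return hook

    return decorator


calls: list[int] = []


@every(interval=3, start=2)
def log_epoch(trainer: Any) -> None:
    calls.append(trainer.epoch)


for e in range(1, 10):
    log_epoch(SimpleNamespace(epoch=e))  # stand-in trainer with an epoch
```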