drytorch.lib.hooks
Module containing registry, callbacks, and hooks for a Trainer.
Functions
- call_every – Create a transformer for periodic hook execution.
- static_hook_class – Class decorator to wrap a callable class into a static hook type.
Classes
- CallEvery – Metadata-aware wrapper for periodic execution.
- ChangeSchedulerOnPlateauCallback – Change the learning rate schedule when a metric has stopped improving.
- EarlyStoppingCallback – Implement early stopping logic for training models.
- Hook – Wrapper for a callable that takes a Trainer as input.
- HookRegistry – A registry for managing and executing hooks.
- MetricExtractor – Handle extraction of metrics from trainer/validation protocols.
- MetricMonitor – Handle metric monitoring and alerts when performance stops improving.
- PruneCallback – Implement pruning logic for training models.
- ReduceLROnPlateau – Reduce the learning rate when a metric has stopped improving.
- RestartScheduleOnPlateau – Restart the schedule after a plateau.
- StaticHook – Ignore arguments and execute a wrapped function.
- TrainerHook – Compose control flow on side effects while keeping track of parameters.
- class CallEvery(wrapped: Callable[[TrainerProtocol[Input, Target, Output]], None], parameters: dict[str, Any] | None = None, interval: int = 1, start: int = 0)
Bases: Hook[Input, Target, Output]
Metadata-aware wrapper for periodic execution.
Initialize.
- Parameters:
- __call__(trainer: TrainerProtocol[Input, Target, Output]) -> None
Optionally call the wrapped callable, depending on the configured interval and start.
- Parameters:
trainer (TrainerProtocol[Input, Target, Output])
- Return type:
None
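The periodic behaviour described for CallEvery can be sketched as follows. This is a hypothetical stand-alone re-implementation, not drytorch's code. `FakeTrainer` and `PeriodicCall` are illustrative names. The sketch also assumes that `start` is the first epoch at which the hook may fire and that the trainer exposes an integer `epoch`.

```python
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeTrainer:
    """Hypothetical trainer stand-in: it only exposes the current epoch."""

    epoch: int = 0


class PeriodicCall:
    """Hypothetical periodic wrapper, not drytorch's CallEvery implementation."""

    def __init__(
        self,
        wrapped: Callable[[FakeTrainer], None],
        parameters: dict[str, Any] | None = None,
        interval: int = 1,
        start: int = 0,
    ) -> None:
        if interval < 1:
            raise ValueError("interval must be a positive integer")
        self.wrapped = wrapped
        self.interval = interval
        self.start = start
        # Record the schedule next to any user-supplied metadata.
        self.parameters = {**(parameters or {}), "interval": interval, "start": start}

    def __call__(self, trainer: FakeTrainer) -> None:
        # Do nothing before `start`; afterwards fire once every `interval` epochs.
        offset = trainer.epoch - self.start
        if offset >= 0 and offset % self.interval == 0:
            self.wrapped(trainer)


calls: list[int] = []
hook = PeriodicCall(lambda t: calls.append(t.epoch), interval=3, start=2)
trainer = FakeTrainer()
for epoch in range(1, 11):
    trainer.epoch = epoch
    hook(trainer)
print(calls)  # [2, 5, 8]
```

Storing `interval` and `start` in `parameters` reflects the "metadata-aware" description: the schedule stays inspectable after wrapping.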
- class ChangeSchedulerOnPlateauCallback(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), cooldown: int = 0)
Bases: Generic[Output, Target]
Change the learning rate schedule when a metric has stopped improving.
- monitor
monitor instance.
- Type:
drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
Initialize.
- Parameters:
metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.
monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.
min_delta (float) – minimum change required to qualify as an improvement.
patience (int) – number of checks to wait before changing the schedule.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.
cooldown (int) – number of calls to skip after changing the schedule.
- __call__(instance: TrainerProtocol[Any, Target, Output]) -> None
Check for a plateau and change the schedule if needed.
- Parameters:
instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to evaluate.
- Return type:
None
- abstractmethod get_scheduler(epoch: int, scheduler: SchedulerProtocol) -> SchedulerProtocol
Modify the input scheduler.
- Parameters:
epoch (int) – current epoch.
scheduler (SchedulerProtocol) – scheduler to be modified.
- Returns:
Modified scheduler.
- Return type:
SchedulerProtocol
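The plateau check that drives these scheduler callbacks can be sketched as follows. This is a hypothetical illustration, not drytorch's implementation. `PlateauDetector` is an illustrative name. Lower values are assumed to be better, and `filter_fn` and `best_is='auto'` are left out. Following the convention of PyTorch's `ReduceLROnPlateau`, the sketch triggers once more than `patience` consecutive checks fail to improve by `min_delta`, then ignores the next `cooldown` calls.

```python
from __future__ import annotations


class PlateauDetector:
    """Hypothetical plateau check with patience and cooldown (lower is better)."""

    def __init__(self, min_delta: float = 1e-8, patience: int = 0, cooldown: int = 0) -> None:
        self.min_delta = min_delta
        self.patience = patience
        self.cooldown = cooldown
        self.best: float | None = None
        self.bad_checks = 0
        self.cooldown_left = 0

    def step(self, value: float) -> bool:
        """Record a value; return True when the schedule should change."""
        if self.cooldown_left > 0:
            # Skip this call entirely while cooling down.
            self.cooldown_left -= 1
            return False
        if self.best is None or value < self.best - self.min_delta:
            # Improvement: remember it and reset the counter.
            self.best = value
            self.bad_checks = 0
            return False
        self.bad_checks += 1
        if self.bad_checks > self.patience:
            # Plateau detected: reset the counter and start the cooldown.
            self.bad_checks = 0
            self.cooldown_left = self.cooldown
            return True
        return False


detector = PlateauDetector(patience=1, cooldown=1)
losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8, 0.8]
changes = [i for i, loss in enumerate(losses) if detector.step(loss)]
print(changes)  # [3, 7]
```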
- class EarlyStoppingCallback(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 10, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), start_from_epoch: int = 2)
Bases: Generic[Output, Target]
Implement early stopping logic for training models.
- monitor
monitor instance.
- Type:
drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
Initialize.
- Parameters:
metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.
monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.
min_delta (float) – minimum change required to qualify as an improvement.
patience (int) – number of calls to wait before stopping.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.
start_from_epoch (int) – first epoch to start monitoring from.
- __call__(instance: TrainerProtocol[Any, Target, Output]) -> None
Evaluate whether training should be stopped early.
- Parameters:
instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to evaluate.
- Return type:
None
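The early stopping logic can be sketched as follows. This is a hypothetical stand-alone version, not drytorch's code:

- `FakeTrainer`, its `stopped` flag, and `EarlyStopper` are illustrative names.
- Higher values are assumed to be better.
- The metric value is passed in directly instead of being extracted from the trainer.
- Stopping is assumed to happen once more than `patience` consecutive monitored calls fail to improve.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeTrainer:
    """Hypothetical trainer stand-in with an epoch counter and a stop flag."""

    epoch: int = 0
    stopped: bool = False


class EarlyStopper:
    """Hypothetical early-stopping check in which higher metric values are better."""

    def __init__(self, patience: int = 10, min_delta: float = 1e-8, start_from_epoch: int = 2) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.start_from_epoch = start_from_epoch
        self.best = float("-inf")
        self.bad_checks = 0

    def __call__(self, trainer: FakeTrainer, value: float) -> None:
        if trainer.epoch < self.start_from_epoch:
            return  # epochs before start_from_epoch are not monitored
        if value > self.best + self.min_delta:
            self.best = value
            self.bad_checks = 0
            return
        self.bad_checks += 1
        if self.bad_checks > self.patience:
            trainer.stopped = True


trainer = FakeTrainer()
stopper = EarlyStopper(patience=2, start_from_epoch=2)
accuracies = [0.1, 0.5, 0.6, 0.6, 0.6, 0.6, 0.7, 0.8]
for epoch, accuracy in enumerate(accuracies, start=1):
    trainer.epoch = epoch
    stopper(trainer, accuracy)
    if trainer.stopped:
        break
print(trainer.epoch)  # 6
```

Skipping the first epochs with `start_from_epoch` keeps noisy early measurements from setting an unrealistic best value.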
- class Hook(wrapped: Callable[[TrainerProtocol[Input, Target, Output]], None], base_hook_name: str | None = None, parameters: dict[str, Any] | None = None)
Bases: TrainerHook[Input, Target, Output]
Wrapper for a callable that takes a Trainer as input.
Initialize.
- Parameters:
- __call__(trainer: TrainerProtocol[Input, Target, Output]) -> None
Execute the call.
- Parameters:
trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the wrapped function.
- Return type:
None
- class HookRegistry
Bases: Generic[_T_contra]
A registry for managing and executing hooks.
Each hook is called with the same generic input object, which it can access.
- hooks
a list of registered hooks.
- Type:
list[collections.abc.Callable[[drytorch.lib.hooks._T_contra], None]]
Initialize.
- execute(input_object: _T_contra) -> None
Execute the registered hooks in order of registration.
- Parameters:
input_object (_T_contra) – the input to pass to each hook.
- Return type:
None
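The registry behaviour can be sketched as follows. The hooks are stored in a list and run in registration order on the same input object. `SimpleHookRegistry` and its `register` method are hypothetical names; this page does not document how hooks are added to drytorch's `HookRegistry`.

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")


class SimpleHookRegistry(Generic[T]):
    """Hypothetical registry: stores hooks and runs them in registration order."""

    def __init__(self) -> None:
        self.hooks: list[Callable[[T], None]] = []

    def register(self, hook: Callable[[T], None]) -> None:
        # Hooks are kept in a list, so insertion order is execution order.
        self.hooks.append(hook)

    def execute(self, input_object: T) -> None:
        for hook in self.hooks:
            hook(input_object)


log: list[str] = []
registry: SimpleHookRegistry[int] = SimpleHookRegistry()
registry.register(lambda x: log.append(f"first saw {x}"))
registry.register(lambda x: log.append(f"second saw {x}"))
registry.execute(3)
print(log)  # ['first saw 3', 'second saw 3']
```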
- class MetricExtractor(metric: ObjectiveProtocol[Any, Any] | str | None = None, monitor: MonitorProtocol | None = None)
Bases: object
Handle extraction of metrics from trainer/validation protocols.
This class is responsible for interfacing with trainer and validation protocols to extract metric values.
- metric_spec
the metric specification (name or protocol instance).
- Type:
drytorch.core.protocols.ObjectiveProtocol[Any, Any] | str | None
- optional_monitor
evaluation protocol to monitor.
- Type:
drytorch.core.protocols.MonitorProtocol | None
Initialize.
- Parameters:
metric (p.ObjectiveProtocol[Any, Any] | str | None) – name of the metric to monitor or metric calculator instance.
monitor (p.MonitorProtocol | None) – evaluation protocol to monitor.
- extract_metric_value(instance: TrainerProtocol[Input, Target, Output], tracker: MetricTracker[Output, Target]) -> float
Extract and return the metric value from the instance.
- Parameters:
instance (TrainerProtocol[Input, Target, Output]) – Trainer instance to extract from.
tracker (MetricTracker[Output, Target]) – the objectives.MetricTracker whose metric name may be updated.
- Returns:
The extracted metric value.
- Raises:
MetricNotFoundError – if the specified metric is not found.
- Return type:
float
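Metric extraction can be sketched as follows, assuming the computed metrics are available as plain dictionaries. The sketch prefers validation results when they exist, falls back to the first metric when no name is given, and raises an error for unknown names. `extract_metric_value` here is a hypothetical free function, and `MetricNotFoundError` is a local stand-in for drytorch's exception. Neither is the library's code.

```python
from __future__ import annotations


class MetricNotFoundError(KeyError):
    """Local stand-in for drytorch's MetricNotFoundError."""


def extract_metric_value(
    trainer_metrics: dict[str, float],
    validation_metrics: dict[str, float] | None = None,
    metric_name: str | None = None,
) -> float:
    """Hypothetical extraction from plain dictionaries of computed metrics."""
    # Prefer validation results when they exist, otherwise use the trainer's.
    computed = validation_metrics if validation_metrics is not None else trainer_metrics
    if metric_name is None:
        if not computed:
            raise MetricNotFoundError("no metrics were computed")
        # Default to the first metric found (dicts preserve insertion order).
        metric_name = next(iter(computed))
    if metric_name not in computed:
        raise MetricNotFoundError(metric_name)
    return computed[metric_name]


print(extract_metric_value({"loss": 0.5, "accuracy": 0.8}))  # 0.5
print(extract_metric_value({"loss": 0.5}, {"loss": 0.7, "accuracy": 0.6}, "accuracy"))  # 0.6
```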
- class MetricMonitor(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1))
Bases: Generic[Output, Target]
Handle metric monitoring and alerts when performance stops improving.
- metric_tracker
handles metric value tracking and improvement detection.
- Type:
drytorch.lib.objectives.MetricTracker[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
- extractor
handles metric extraction from protocols.
- Type:
drytorch.lib.hooks.MetricExtractor
Initialize.
- Parameters:
metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of the metric to monitor or metric calculator instance.
monitor (p.MonitorProtocol | None) – evaluation protocol to monitor.
min_delta (float) – minimum change required to qualify as an improvement.
patience (int) – number of checks to wait before triggering the callback.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values.
- record_metric_value(instance: TrainerProtocol[Any, Target, Output]) -> None
Register a new metric value from a monitored evaluation.
- Parameters:
instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to extract metric from.
- Raises:
MetricNotFoundError – if the specified metric is not found.
- Return type:
None
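The roles of `filter_fn`, `best_is`, and `min_delta` can be sketched as follows. Each raw value is appended to a history, `filter_fn` reduces the history to one number, and that number is compared with the best so far. `MonitorSketch` and `mean_of_last_two` are hypothetical names. The `'auto'` mode is left out because its rule is not documented on this page.

```python
from __future__ import annotations

import operator
import statistics
from collections.abc import Callable, Sequence
from typing import Literal


class MonitorSketch:
    """Hypothetical tracker: aggregates recent values and detects improvement."""

    def __init__(
        self,
        best_is: Literal["higher", "lower"] = "lower",
        min_delta: float = 1e-8,
        filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1),
    ) -> None:
        self.best_is = best_is
        self.min_delta = min_delta
        self.filter_fn = filter_fn
        self.history: list[float] = []
        self.best: float | None = None

    def record(self, value: float) -> bool:
        """Store a raw value; return True if the filtered value is an improvement."""
        self.history.append(value)
        current = self.filter_fn(self.history)
        if self.best is None:
            improved = True
        elif self.best_is == "lower":
            improved = current < self.best - self.min_delta
        else:
            improved = current > self.best + self.min_delta
        if improved:
            self.best = current
        return improved


def mean_of_last_two(values: Sequence[float]) -> float:
    """Smoothing filter: average of the two most recent values."""
    return statistics.fmean(values[-2:])


raw = MonitorSketch(best_is="lower")
smooth = MonitorSketch(best_is="lower", filter_fn=mean_of_last_two)
losses = [4.0, 2.0, 6.0, 1.0]
print([raw.record(x) for x in losses])  # [True, True, False, True]
print([smooth.record(x) for x in losses])  # [True, True, False, False]
```

With the default `operator.itemgetter(-1)`, the single low value at the end counts as an improvement. With the smoothing filter, the spike just before it still weighs on the average, so no improvement is reported.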
- class PruneCallback(thresholds: Mapping[int, float | None], metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1))
Bases: Generic[Output, Target]
Implement pruning logic for training models.
- monitor
monitor instance.
- Type:
drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
- thresholds
dictionary mapping epochs to pruning thresholds.
- Type:
collections.abc.Mapping[int, float | None]
Initialize.
- Parameters:
thresholds (Mapping[int, float | None]) – dictionary mapping epochs to pruning values.
metric (str | p.ObjectiveProtocol[Output, Target] | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.
monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.
min_delta (float) – minimum change required to qualify as an improvement.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate the intermediate metric values. Default gets the last value.
- __call__(instance: TrainerProtocol[Any, Target, Output]) -> None
Evaluate whether training should be stopped early.
- Parameters:
instance (TrainerProtocol[Any, Target, Output]) – trainer instance to evaluate.
- Return type:
None
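A threshold-based pruning rule can be sketched as follows. The sketch looks up the threshold for the current epoch and prunes if the metric falls on the wrong side of it. `should_prune` is a hypothetical helper, not drytorch's code. Two further points are assumptions: a missing or `None` threshold means "no check at this epoch", and `min_delta` and `filter_fn` are left out for brevity.

```python
from __future__ import annotations

from collections.abc import Mapping


def should_prune(
    epoch: int,
    value: float,
    thresholds: Mapping[int, float | None],
    best_is: str = "higher",
) -> bool:
    """Hypothetical rule: prune when the metric misses this epoch's threshold."""
    threshold = thresholds.get(epoch)
    if threshold is None:
        return False  # no threshold at this epoch, or explicitly disabled
    if best_is == "higher":
        return value < threshold
    return value > threshold


thresholds = {5: 0.5, 10: 0.8, 15: None}
print(should_prune(5, 0.4, thresholds))  # True: below the epoch-5 bar
print(should_prune(7, 0.1, thresholds))  # False: no threshold at epoch 7
```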
- class ReduceLROnPlateau(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), factor: float = 0.1, cooldown: int = 0)
Bases: ChangeSchedulerOnPlateauCallback[Output, Target]
Reduce the learning rate when a metric has stopped improving.
- monitor
monitor instance.
- Type:
drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
Initialize.
- Parameters:
metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.
monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.
min_delta (float) – minimum change required to qualify as an improvement.
patience (int) – number of checks to wait before changing the schedule.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.
cooldown (int) – number of calls to skip after changing the schedule.
factor (float) – factor by which to reduce the learning rate.
- get_scheduler(epoch: int, scheduler: SchedulerProtocol) -> SchedulerProtocol
Modify the input scheduler to scale down the learning rate.
- Parameters:
epoch (int) – not used.
scheduler (SchedulerProtocol) – scheduler to be modified.
- Returns:
Modified scheduler.
- Return type:
SchedulerProtocol
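Scaling a schedule down by `factor` can be sketched as wrapping it in a new schedule. The sketch assumes a hypothetical scheduler shape: a callable mapping `(base_lr, epoch)` to a learning rate. This page does not document the actual `SchedulerProtocol` interface. `constant` and `reduce_on_plateau` are illustrative names.

```python
from __future__ import annotations

from collections.abc import Callable

# Hypothetical scheduler shape: (base learning rate, epoch) -> learning rate.
Scheduler = Callable[[float, int], float]


def constant(base_lr: float, epoch: int) -> float:
    """A schedule that never changes the learning rate."""
    return base_lr


def reduce_on_plateau(epoch: int, scheduler: Scheduler, factor: float = 0.1) -> Scheduler:
    """Wrap `scheduler` so every learning rate it produces is scaled by `factor`."""
    del epoch  # unused, mirroring the documented get_scheduler

    def scaled(base_lr: float, current_epoch: int) -> float:
        return factor * scheduler(base_lr, current_epoch)

    return scaled


once = reduce_on_plateau(10, constant, factor=0.5)
twice = reduce_on_plateau(20, once, factor=0.5)
print(once(0.1, 15), twice(0.1, 25))  # 0.05 0.025
```

Because each reduction wraps the previous schedule, repeated plateaus compound: two reductions by 0.5 give a quarter of the original rate.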
- class RestartScheduleOnPlateau(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), cooldown: int = 0)
Bases: ChangeSchedulerOnPlateauCallback[Output, Target]
Restart the schedule after a plateau.
- monitor
monitor instance.
- Type:
drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
Initialize.
- Parameters:
metric (p.ObjectiveProtocol[Output, Target] | str | None) – name of metric to monitor or metric calculator instance. Defaults to the first metric found.
monitor (p.MonitorProtocol | None) – evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise.
min_delta (float) – minimum change required to qualify as an improvement.
patience (int) – number of checks to wait before changing the schedule.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.
cooldown (int) – number of calls to skip after changing the schedule.
- get_scheduler(epoch: int, scheduler: SchedulerProtocol) -> SchedulerProtocol
Treat the training so far as a warm-up and restart the schedule.
- Parameters:
epoch (int) – current epoch.
scheduler (SchedulerProtocol) – scheduler to be modified.
- Returns:
Modified scheduler.
- Return type:
SchedulerProtocol
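Restarting a schedule can be sketched as shifting its epoch counter, so that the plateau epoch becomes epoch 0 of a fresh run. As before, the scheduler shape, a callable mapping `(base_lr, epoch)` to a learning rate, is a hypothetical assumption. `halve_every_ten` and `restart_on_plateau` are illustrative names.

```python
from __future__ import annotations

from collections.abc import Callable

# Hypothetical scheduler shape: (base learning rate, epoch) -> learning rate.
Scheduler = Callable[[float, int], float]


def halve_every_ten(base_lr: float, epoch: int) -> float:
    """Step decay: halve the learning rate every ten epochs."""
    return base_lr * 0.5 ** (epoch // 10)


def restart_on_plateau(epoch: int, scheduler: Scheduler) -> Scheduler:
    """Treat epochs before `epoch` as warm-up and run `scheduler` again from its start."""

    def restarted(base_lr: float, current_epoch: int) -> float:
        # Shift the clock so the plateau epoch becomes epoch 0 of the new run.
        return scheduler(base_lr, max(current_epoch - epoch, 0))

    return restarted


new_schedule = restart_on_plateau(20, halve_every_ten)
print(halve_every_ten(0.1, 25), new_schedule(0.1, 25), new_schedule(0.1, 35))  # 0.025 0.1 0.05
```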
- class StaticHook(wrapped: Callable[[], None], base_hook_name: str | None = None, parameters: dict[str, Any] | None = None)
Bases: TrainerHook[Any, Any, Any]
Ignore arguments and execute a wrapped function.
Initialize.
- Parameters:
- __call__(trainer: TrainerProtocol[Input, Target, Output]) -> None
Execute the call.
- Parameters:
trainer (TrainerProtocol[Input, Target, Output]) – the trainer, which is ignored; the wrapped function is called without arguments.
- Return type:
None
- class TrainerHook
Bases: Generic[Input, Target, Output]
Compose control flow on side effects while keeping track of parameters.
- Parameters:
parameters – metadata associated with the hook
base_hook_name – name of the base hook for representation
- abstractmethod __call__(trainer: TrainerProtocol[Input, Target, Output]) -> None
Execute the call.
- Parameters:
trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the wrapped function.
- Return type:
None
- bind(f: Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], Hook], /) -> Hook[Input, Target, Output]
Allow transformation of the Hook.
- Parameters:
f (Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], Hook]) – a function specifying the transformation.
- Returns:
the transformed Hook.
- Return type:
Hook[Input, Target, Output]
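Judging from its signature, `bind` passes the hook itself (a callable taking a trainer) to `f` and returns the hook that `f` builds. The sketch below assumes exactly that behaviour. `SketchHook` and `with_logging` are hypothetical names, and the trainer is replaced by a plain string.

```python
from __future__ import annotations

from collections.abc import Callable
from typing import Any


class SketchHook:
    """Hypothetical hook: wraps a trainer callable and carries metadata."""

    def __init__(self, wrapped: Callable[[Any], None], parameters: dict[str, Any] | None = None) -> None:
        self.wrapped = wrapped
        self.parameters = parameters or {}

    def __call__(self, trainer: Any) -> None:
        self.wrapped(trainer)

    def bind(self, f: Callable[[Callable[[Any], None]], SketchHook], /) -> SketchHook:
        # Hand the whole hook to the transformation and return the hook it builds.
        return f(self)


events: list[str] = []


def with_logging(hook: Callable[[Any], None]) -> SketchHook:
    """A transformation that records an event before running the original hook."""

    def logged(trainer: Any) -> None:
        events.append("before")
        hook(trainer)

    return SketchHook(logged, {"logged": True})


hook = SketchHook(lambda trainer: events.append(f"ran on {trainer}"))
logged_hook = hook.bind(with_logging)
logged_hook("trainer-1")
print(events)  # ['before', 'ran on trainer-1']
```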
- call_every(interval: int, start: int = 0) -> Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], Hook[Input, Target, Output]]
Create a transformer for periodic hook execution.
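The return type suggests that `call_every` builds a transformation suited to `bind`: it takes a hook and returns a hook that runs only on selected epochs. The sketch below is a hypothetical self-contained version of that pattern. `FakeTrainer`, `SketchHook`, and the epoch-based firing rule are assumptions, not drytorch's code.

```python
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class FakeTrainer:
    """Hypothetical trainer stand-in exposing only the current epoch."""

    epoch: int = 0


TrainerFn = Callable[[FakeTrainer], None]


class SketchHook:
    """Hypothetical hook wrapping a trainer callable."""

    def __init__(self, wrapped: TrainerFn) -> None:
        self.wrapped = wrapped

    def __call__(self, trainer: FakeTrainer) -> None:
        self.wrapped(trainer)

    def bind(self, f: Callable[[TrainerFn], SketchHook], /) -> SketchHook:
        return f(self)


def call_every(interval: int, start: int = 0) -> Callable[[TrainerFn], SketchHook]:
    """Hypothetical transformer factory: the result wraps a hook to run periodically."""

    def transform(hook: TrainerFn) -> SketchHook:
        def periodic(trainer: FakeTrainer) -> None:
            # Fire from `start` onwards, once every `interval` epochs.
            if trainer.epoch >= start and (trainer.epoch - start) % interval == 0:
                hook(trainer)

        return SketchHook(periodic)

    return transform


seen: list[int] = []
every_other = SketchHook(lambda t: seen.append(t.epoch)).bind(call_every(2, start=1))
trainer = FakeTrainer()
for epoch in range(1, 7):
    trainer.epoch = epoch
    every_other(trainer)
print(seen)  # [1, 3, 5]
```

Returning a transformer rather than a wrapped hook lets the same periodic policy be reused across hooks through `bind`.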
- static_hook_class(cls: Callable[[_P], Callable[[], None]]) -> Callable[[_P], StaticHook]
Class decorator to wrap a callable class into a static hook type.
- Parameters:
cls (Callable[[~_P], Callable[[], None]]) – a class whose instances are callables that take no arguments and return None.
- Returns:
A callable that accepts the same constructor arguments as cls and returns a static hook wrapping the new instance.
- Return type:
Callable[[~_P], StaticHook]