drytorch.lib.hooks
Module containing registry, callbacks, and hooks for a Trainer.
Functions
- call_every – Create a decorator for periodic hook execution.
- Class decorator to wrap a callable class into a static hook type.
Classes
- CallEvery – Call a function at specified intervals.
- ChangeSchedulerOnPlateauCallback – Change the learning rate schedule when a metric has stopped improving.
- EarlyStoppingCallback – Implement early stopping logic for training models.
- Hook – Wrapper for a callable that takes a Trainer as input.
- A registry for managing and executing hooks.
- MetricExtractor – Handle extraction of metrics from trainer/validation protocols.
- MetricMonitor – Handle metric monitoring and alerts when performance stops improving.
- OptionalCallable – Abstract class for callables that execute based on custom conditions.
- PruneCallback – Implement pruning logic for training models.
- ReduceLROnPlateau – Reduce the learning rate when a metric has stopped improving.
- Restart the scheduling after plateauing.
- Ignore arguments and execute a wrapped function.
- TrainerHook – Callable supporting bind operations.
- class ChangeSchedulerOnPlateauCallback(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), cooldown: int = 0)
Bases: Generic[Output, Target]
Change the learning rate schedule when a metric has stopped improving.
- monitor
the MetricMonitor instance that tracks the monitored metric.
- Type:
drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
Initialize.
- Parameters:
metric (ObjectiveProtocol[Output, Target] | str | None) – name of the metric to monitor or metric calculator instance. Defaults to the first metric found.
monitor (MonitorProtocol | None) – evaluation protocol to monitor. Defaults to the validation protocol if available, otherwise the trainer instance.
min_delta (float) – minimum change required to qualify as an improvement.
patience (int) – number of checks to wait before changing the schedule.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.
cooldown (int) – number of calls to skip after changing the schedule.
- __call__(instance: TrainerProtocol[Any, Target, Output]) → None
Check whether the metric has plateaued and change the learning rate schedule if needed.
- Parameters:
instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to evaluate.
- Return type:
None
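The interaction between patience and cooldown can be illustrated with a small stand-alone sketch. It does not import drytorch: PlateauCounter and its step method are hypothetical names, lower values are assumed to be better, and the counting rule (a change is due once non-improving checks exceed patience, after which cooldown checks are skipped) follows the convention of PyTorch's ReduceLROnPlateau rather than anything confirmed for this class.

```python
import math
from dataclasses import dataclass


@dataclass
class PlateauCounter:
    """Hypothetical helper (not drytorch): decide when a schedule change is due.

    Assumes lower metric values are better.
    """

    patience: int = 0
    cooldown: int = 0
    min_delta: float = 1e-8
    best: float = math.inf
    bad_checks: int = 0
    cooldown_left: int = 0

    def step(self, value: float) -> bool:
        """Record one metric value; return True if the schedule should change now."""
        # An improvement must beat the best value by more than min_delta.
        if value < self.best - self.min_delta:
            self.best = value
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        if self.cooldown_left > 0:
            # Checks during the cooldown are skipped and reset the bad-check count.
            self.cooldown_left -= 1
            self.bad_checks = 0
            return False
        if self.bad_checks > self.patience:
            # Trigger a change, then start the cooldown period.
            self.bad_checks = 0
            self.cooldown_left = self.cooldown
            return True
        return False


counter = PlateauCounter(patience=1, cooldown=1)
losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]
decisions = [counter.step(loss) for loss in losses]
```

With patience=1, the change fires on the second consecutive check without improvement, and the check right after it is absorbed by the cooldown.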
- abstractmethod get_scheduler(epoch: int, scheduler: SchedulerProtocol) → SchedulerProtocol
Modify the input scheduler.
- Parameters:
epoch (int) – current epoch.
scheduler (SchedulerProtocol) – scheduler to be modified.
- Returns:
Modified scheduler.
- Return type:
SchedulerProtocol
- class EarlyStoppingCallback(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 10, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), start_from_epoch: int = 2)
Bases: Generic[Output, Target]
Implement early stopping logic for training models.
- monitor
the MetricMonitor instance that tracks the monitored metric.
- Type:
drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
Initialize.
- Parameters:
metric (ObjectiveProtocol[Output, Target] | str | None) – name of the metric to monitor or metric calculator instance. Defaults to the first metric found.
monitor (MonitorProtocol | None) – evaluation protocol to monitor. Defaults to the validation protocol if available, otherwise the trainer instance.
min_delta (float) – minimum change required to qualify as an improvement.
patience (int) – number of calls to wait before stopping.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.
start_from_epoch (int) – first epoch to start monitoring from.
- __call__(instance: TrainerProtocol[Any, Target, Output]) → None
Evaluate whether training should be stopped early.
- Parameters:
instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to evaluate.
- Return type:
None
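The following is a rough sketch of the stopping rule implied by the parameters above, written without drytorch. EarlyStopper and should_stop are hypothetical names. best_is='auto' is not modelled, and the assumption that training stops once the count of non-improving checks exceeds patience is a guess about the exact boundary.

```python
from dataclasses import dataclass
from typing import Literal, Optional


@dataclass
class EarlyStopper:
    """Hypothetical stand-in for the stopping rule; not drytorch code."""

    patience: int = 10
    min_delta: float = 1e-8
    best_is: Literal['higher', 'lower'] = 'lower'
    start_from_epoch: int = 2
    best: Optional[float] = None
    bad_checks: int = 0

    def _improved(self, value: float) -> bool:
        # The first monitored value always counts as the best so far.
        if self.best is None:
            return True
        if self.best_is == 'lower':
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def should_stop(self, epoch: int, value: float) -> bool:
        """Return True once the metric fails to improve for more than `patience` checks."""
        if epoch < self.start_from_epoch:
            return False  # before start_from_epoch the metric is ignored
        if self._improved(value):
            self.best = value
            self.bad_checks = 0
            return False
        self.bad_checks += 1
        return self.bad_checks > self.patience


stopper = EarlyStopper(patience=1, best_is='higher', start_from_epoch=2)
accuracy = {1: 0.10, 2: 0.50, 3: 0.60, 4: 0.60, 5: 0.59}  # made-up values
stops = [stopper.should_stop(epoch, acc) for epoch, acc in accuracy.items()]
```

Epoch 1 is ignored because it precedes start_from_epoch. Training stops at epoch 5, after two checks without improvement.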
- class Hook(wrapped: Callable[[TrainerProtocol[Input, Target, Output]], None])
Bases: TrainerHook[Input, Target, Output]
Wrapper for a callable that takes a Trainer as input.
- wrapped
the function to be called.
- Type:
collections.abc.Callable[[drytorch.core.protocols.TrainerProtocol[drytorch.lib.hooks.Input, drytorch.lib.hooks.Target, drytorch.lib.hooks.Output]], None]
Initialize.
- Parameters:
wrapped (Callable[[TrainerProtocol[Input, Target, Output]], None]) – the function to be called.
- __call__(trainer: TrainerProtocol[Input, Target, Output]) → None
Execute the call.
- Parameters:
trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the wrapped function.
- Return type:
None
- class MetricExtractor(metric: ObjectiveProtocol[Any, Any] | str | None = None, monitor: MonitorProtocol | None = None)
Bases: object
Handle extraction of metrics from trainer/validation protocols.
This class is responsible for interfacing with trainer and validation protocols to extract metric values.
- metric_spec
the metric specification (name or protocol instance).
- Type:
drytorch.core.protocols.ObjectiveProtocol[Any, Any] | str | None
- optional_monitor
evaluation protocol to monitor.
- Type:
drytorch.core.protocols.MonitorProtocol | None
Initialize.
- Parameters:
metric (ObjectiveProtocol[Any, Any] | str | None) – name of the metric to monitor or metric calculator instance.
monitor (MonitorProtocol | None) – evaluation protocol to monitor.
- extract_metric_value(instance: TrainerProtocol[Input, Target, Output], tracker: MetricTracker[Output, Target]) → float
Extract and return the metric value from the instance.
- Parameters:
instance (TrainerProtocol[Input, Target, Output]) – Trainer instance to extract from.
tracker (MetricTracker[Output, Target]) – the objectives.MetricTracker whose metric name may be updated.
- Returns:
The extracted metric value.
- Raises:
MetricNotFoundError – if the specified metric is not found.
- Return type:
float
- class MetricMonitor(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1))
Bases: Generic[Output, Target]
Handle metric monitoring and alerts when performance stops improving.
- metric_tracker
handles metric value tracking and improvement detection.
- Type:
drytorch.lib.objectives.MetricTracker[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
- extractor
handles metric extraction from protocols.
Initialize.
- Parameters:
metric (ObjectiveProtocol[Output, Target] | str | None) – name of the metric to monitor or metric calculator instance.
monitor (MonitorProtocol | None) – evaluation protocol to monitor.
min_delta (float) – minimum change required to qualify as an improvement.
patience (int) – number of checks to wait before triggering the callback.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values.
- record_metric_value(instance: TrainerProtocol[Any, Target, Output]) → None
Register a new metric value from a monitored evaluation.
- Parameters:
instance (TrainerProtocol[Any, Target, Output]) – Trainer instance to extract metric from.
- Raises:
MetricNotFoundError – if the specified metric is not found.
- Return type:
None
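filter_fn reduces the recorded metric values to the single number that is compared against the best value so far. The sketch below assumes the values arrive oldest first. It contrasts the default operator.itemgetter(-1) with mean_of_last_three, a hypothetical smoothing function that is not part of drytorch.

```python
from __future__ import annotations

import operator
import statistics
from collections.abc import Callable, Sequence

history = [0.52, 0.48, 0.50, 0.47]  # made-up validation losses, oldest first

# Default used throughout this module: keep only the most recent value.
last: Callable[[Sequence[float]], float] = operator.itemgetter(-1)


def mean_of_last_three(values: Sequence[float]) -> float:
    """Hypothetical smoother: average the (up to) three most recent values."""
    return statistics.fmean(values[-3:])


latest = last(history)  # 0.47
smoothed = mean_of_last_three(history)  # about 0.4833
```

A function like this could be passed as filter_fn to MetricMonitor or to the callbacks in this module, making them less sensitive to a single noisy evaluation.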
- class OptionalCallable(wrapped: Callable[[TrainerProtocol[Input, Target, Output]], None])
Bases: Hook[Input, Target, Output]
Abstract class for callables that execute based on custom conditions.
Initialize.
- Parameters:
wrapped (Callable[[TrainerProtocol[Input, Target, Output]], None]) – the function to be conditionally called.
- __call__(trainer: TrainerProtocol[Input, Target, Output]) → None
Execute the call.
- Parameters:
trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the wrapped function.
- Return type:
None
- class PruneCallback(thresholds: Mapping[int, float | None], metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1))
Bases: Generic[Output, Target]
Implement pruning logic for training models.
- monitor
the MetricMonitor instance that tracks the monitored metric.
- Type:
drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
- thresholds
dictionary mapping epochs to pruning thresholds.
- Type:
collections.abc.Mapping[int, float | None]
Initialize.
- Parameters:
thresholds (Mapping[int, float | None]) – dictionary mapping epochs to pruning thresholds.
metric (ObjectiveProtocol[Output, Target] | str | None) – name of the metric to monitor or metric calculator instance. Defaults to the first metric found.
monitor (MonitorProtocol | None) – evaluation protocol to monitor. Defaults to the validation protocol if available, otherwise the trainer instance.
min_delta (float) – minimum change required to qualify as an improvement.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate the intermediate result values. Default gets the last value.
- __call__(instance: TrainerProtocol[Any, Target, Output]) → None
Evaluate whether training should be stopped early.
- Parameters:
instance (TrainerProtocol[Any, Target, Output]) – trainer instance to evaluate.
- Return type:
None
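The thresholds mapping can be read as a per-epoch pass/fail check. The following sketch is an interpretation, not drytorch code. should_prune is hypothetical, min_delta and filter_fn are left out, and a missing epoch or an explicit None is assumed to mean that no check happens at that epoch.

```python
from collections.abc import Mapping
from typing import Literal, Optional


def should_prune(
    epoch: int,
    value: float,
    thresholds: Mapping[int, Optional[float]],
    best_is: Literal['higher', 'lower'] = 'lower',
) -> bool:
    """Hypothetical rule: prune when the metric is worse than the epoch's threshold."""
    threshold = thresholds.get(epoch)
    if threshold is None:
        # No entry, or an explicit None: assume nothing is checked at this epoch.
        return False
    if best_is == 'lower':
        return value > threshold  # loss-like metric: above the threshold is worse
    return value < threshold  # score-like metric: below the threshold is worse


thresholds = {5: 1.0, 10: 0.5, 15: None}
```

For example, should_prune(5, 1.2, thresholds) is True because a loss of 1.2 is worse than the threshold of 1.0 set for epoch 5.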
- class ReduceLROnPlateau(metric: ObjectiveProtocol[Output, Target] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, patience: int = 0, best_is: Literal['auto', 'higher', 'lower'] = 'auto', filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), factor: float = 0.1, cooldown: int = 0)
Bases: ChangeSchedulerOnPlateauCallback[Output, Target]
Reduce the learning rate when a metric has stopped improving.
- monitor
the MetricMonitor instance that tracks the monitored metric.
- Type:
drytorch.lib.hooks.MetricMonitor[drytorch.lib.hooks.Output, drytorch.lib.hooks.Target]
Initialize.
- Parameters:
metric (ObjectiveProtocol[Output, Target] | str | None) – name of the metric to monitor or metric calculator instance. Defaults to the first metric found.
monitor (MonitorProtocol | None) – evaluation protocol to monitor. Defaults to the validation protocol if available, otherwise the trainer instance.
min_delta (float) – minimum change required to qualify as an improvement.
patience (int) – number of checks to wait before changing the schedule.
best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. Default ‘auto’ will determine this from initial measurements.
filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Default gets the last value.
cooldown (int) – number of calls to skip after changing the schedule.
factor (float) – factor by which to reduce the learning rate.
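If get_scheduler returns a new scheduler wrapping the previous one, repeated plateaus compound the reduction. In the sketch below, Schedule, ConstantSchedule, ScaledSchedule and reduce_on_plateau are hypothetical stand-ins. The real SchedulerProtocol interface may differ, and modelling a schedule as a function of (base_lr, epoch) is an assumption.

```python
from typing import Protocol


class Schedule(Protocol):
    """Hypothetical stand-in for SchedulerProtocol: maps (base_lr, epoch) to a learning rate."""

    def __call__(self, base_lr: float, epoch: int) -> float: ...


class ConstantSchedule:
    """Always return the base learning rate."""

    def __call__(self, base_lr: float, epoch: int) -> float:
        return base_lr


class ScaledSchedule:
    """Wrap a schedule and multiply its output by a constant factor."""

    def __init__(self, inner: Schedule, factor: float) -> None:
        self.inner = inner
        self.factor = factor

    def __call__(self, base_lr: float, epoch: int) -> float:
        return self.factor * self.inner(base_lr, epoch)


def reduce_on_plateau(schedule: Schedule, factor: float = 0.1) -> Schedule:
    """Mimic get_scheduler: return a scaled-down version of the input schedule."""
    return ScaledSchedule(schedule, factor)


schedule: Schedule = ConstantSchedule()
schedule = reduce_on_plateau(schedule)  # first plateau
schedule = reduce_on_plateau(schedule)  # second plateau
lr = schedule(1e-3, epoch=20)
```

With the default factor of 0.1, two plateaus leave one hundredth of the base learning rate, here 1e-5.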
- get_scheduler(epoch: int, scheduler: SchedulerProtocol) → SchedulerProtocol
Modify the input scheduler to scale down the learning rate.
- Parameters:
epoch (int) – not used.
scheduler (SchedulerProtocol) – scheduler to be modified.
- Returns:
Modified scheduler.
- Return type:
SchedulerProtocol
- class TrainerHook
Bases: Generic[Input, Target, Output]
Callable supporting bind operations.
- abstractmethod __call__(trainer: TrainerProtocol[Input, Target, Output]) → None
Execute the call.
- Parameters:
trainer (TrainerProtocol[Input, Target, Output]) – the trainer to pass to the hook.
- Return type:
None
- bind(f: Callable[[TrainerHook[Input, Target, Output]], TrainerHook[Input, Target, Output]], /) → TrainerHook[Input, Target, Output]
Transform the hook with the given function.
- Parameters:
f (Callable[[TrainerHook[Input, Target, Output]], TrainerHook[Input, Target, Output]]) – a function specifying the transformation.
- Returns:
the transformed Hook.
- Return type:
TrainerHook[Input, Target, Output]
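The bind operation can be pictured as passing the hook to a transformation and getting a new hook back. The sketch below re-creates that pattern with hypothetical classes (FakeTrainer, TrainerHookSketch, HookSketch) and a hypothetical transformation, only_after. It is not drytorch's implementation and does not use the real TrainerProtocol.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class FakeTrainer:
    """Hypothetical stand-in for TrainerProtocol: tracks the epoch and a log."""

    epoch: int = 0
    log: list[str] = field(default_factory=list)


class TrainerHookSketch(ABC):
    @abstractmethod
    def __call__(self, trainer: FakeTrainer) -> None: ...

    def bind(self, f: Callable[[TrainerHookSketch], TrainerHookSketch], /) -> TrainerHookSketch:
        # bind hands this hook to f and returns whatever f builds from it.
        return f(self)


class HookSketch(TrainerHookSketch):
    """Wrap a plain function that takes the trainer."""

    def __init__(self, wrapped: Callable[[FakeTrainer], None]) -> None:
        self.wrapped = wrapped

    def __call__(self, trainer: FakeTrainer) -> None:
        self.wrapped(trainer)


def only_after(epoch: int) -> Callable[[TrainerHookSketch], TrainerHookSketch]:
    """Hypothetical transformation: run the hook only after a given epoch."""

    def transform(hook: TrainerHookSketch) -> TrainerHookSketch:
        def guarded(trainer: FakeTrainer) -> None:
            if trainer.epoch > epoch:
                hook(trainer)

        return HookSketch(guarded)

    return transform


hook = HookSketch(lambda t: t.log.append(f"epoch {t.epoch}")).bind(only_after(2))
trainer = FakeTrainer()
for epoch in range(1, 5):
    trainer.epoch = epoch
    hook(trainer)
```

Keeping the transformation outside the hook class lets conditions such as only_after be combined with any hook without subclassing.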
- call_every(interval: int, start: int = 0) → Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], CallEvery[Input, Target, Output]]
Create a decorator for periodic hook execution.
- Parameters:
- Returns:
A decorator that wraps a function in a CallEvery hook.
- Return type:
Callable[[Callable[[TrainerProtocol[Input, Target, Output]], None]], CallEvery[Input, Target, Output]]
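call_every returns a decorator that turns a plain function into a periodic hook. The sketch below mimics that shape with hypothetical names (FakeTrainer, CallEverySketch, call_every_sketch). It assumes the period is counted in epochs and that the first call happens at epoch start. Neither assumption is confirmed by the reference above.

```python
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class FakeTrainer:
    """Hypothetical stand-in for TrainerProtocol: only exposes the current epoch."""

    epoch: int = 0


class CallEverySketch:
    """Call the wrapped function every `interval` epochs from `start` (assumed semantics)."""

    def __init__(self, interval: int, start: int, wrapped: Callable[[FakeTrainer], None]) -> None:
        if interval < 1:
            raise ValueError("interval must be a positive integer")
        self.interval = interval
        self.start = start
        self.wrapped = wrapped

    def __call__(self, trainer: FakeTrainer) -> None:
        epoch = trainer.epoch
        # Fire at start, start + interval, start + 2 * interval, ...
        if epoch >= self.start and (epoch - self.start) % self.interval == 0:
            self.wrapped(trainer)


def call_every_sketch(
    interval: int, start: int = 0
) -> Callable[[Callable[[FakeTrainer], None]], CallEverySketch]:
    """Return a decorator that wraps a function in a CallEverySketch hook."""

    def decorator(func: Callable[[FakeTrainer], None]) -> CallEverySketch:
        return CallEverySketch(interval, start, func)

    return decorator


seen: list[int] = []


@call_every_sketch(interval=3, start=2)
def record_epoch(trainer: FakeTrainer) -> None:
    seen.append(trainer.epoch)


trainer = FakeTrainer()
for epoch in range(10):
    trainer.epoch = epoch
    record_epoch(trainer)
```

Over epochs 0 to 9, the decorated function runs at epochs 2, 5 and 8.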