"""Support for optuna."""
import string
from collections.abc import Callable, Sequence
from typing import Any, Final, Generic, Literal, TypeVar
import optuna
from omegaconf import DictConfig
from drytorch.core import exceptions
from drytorch.core import protocols as p
from drytorch.lib import hooks
__all__ = [
'OptunaError',
'TrialCallback',
'get_final_value',
'suggest_overrides',
]
_Target_contra = TypeVar(
'_Target_contra', bound=p.TargetType, contravariant=True
)
_Output_contra = TypeVar(
'_Output_contra', bound=p.OutputType, contravariant=True
)
[docs]
class OptunaError(exceptions.DryTorchError):
"""Base class for Optuna errors."""
_template = 'Optuna: {}'
[docs]
class TrialCallback(Generic[_Output_contra, _Target_contra]):
"""Implements pruning logic for training models.
Attributes:
monitor: Monitor instance
trial: Optuna trial.
reported: Dictionary mapping epochs to reported values.
"""
monitor: hooks.MetricMonitor
trial: optuna.Trial
reported: dict[int, float]
def __init__(
self,
trial: optuna.Trial,
filter_fn: Callable[[Sequence[float]], float] = hooks.get_last,
metric: p.ObjectiveProtocol[_Output_contra, _Target_contra]
| str
| None = None,
monitor: p.MonitorProtocol | None = None,
min_delta: float = 1e-8,
best_is: Literal['auto', 'higher', 'lower'] = 'auto',
) -> None:
"""Initialize.
Args:
trial: Optuna trial
filter_fn: function to aggregate recent metric values.
metric: Name of metric to monitor or metric calculator instance.
Defaults to the first metric found.
monitor: Evaluation protocol to monitor. Defaults to validation
if available, trainer instance otherwise.
min_delta: Minimum change required to qualify as an improvement.
best_is: Whether higher or lower metric values are better. Default
'auto' will determine this from the first measurements.
"""
self.monitor: Final = hooks.MetricMonitor(
metric=metric,
monitor=monitor,
min_delta=min_delta,
best_is=best_is,
filter_fn=filter_fn,
)
self.trial: Final = trial
self.reported: Final = dict[int, float]()
return
[docs]
def __call__(
self,
instance: p.TrainerProtocol[Any, _Target_contra, _Output_contra],
) -> None:
"""Evaluate whether training should be stopped early.
Args:
instance: Trainer instance to evaluate.
Raises:
optuna.TrialPruned: if the trial should be pruned.
"""
self.monitor.record_metric_value(instance)
epoch = instance.model.epoch
value = self.monitor.filtered_value
self.trial.report(value, epoch)
self.reported[epoch] = value
if self.trial.should_prune():
metric_name = self.monitor.metric_name
msg = f'Optuna pruning while monitoring {metric_name}'
instance.terminate_training(msg)
raise optuna.TrialPruned()
return
[docs]
def suggest_overrides(
tune_cfg: DictConfig, trial: optuna.Trial, use_full_name: bool = False
) -> list[str]:
"""Suggest values for a trial from structured configurations.
This function helps integrate optuna into hydra by specifying trial
parameters present in the hydra run configuration.
The configuration file (loadable with hydra) should follow this structure:
.. code-block:: yaml
tune:
params:
param_name:
suggest: "suggest_float" # or other optuna suggest method
settings:
low: 0.0
high: 1.0
list_param:
suggest: "suggest_list"
settings:
min_length: 1
max_length: 5
suggest: "suggest_float" # method for sampling list elements
settings:
low: 0.0
high: 1.0
overrides: [] # additional static overrides
For 'suggest_list' configurations, the settings must specify:
- min_length and max_length: bounds for the size of the list.
- nested suggest and settings: used to sample each list element.
The resulting values can be used with hydra.initialize and hydra.compose.
Example usage:
.. code-block:: python
import hydra
with hydra.initialize(version_base=None, config_path='path/to/config'):
overrides = suggest_overrides(tune_cfg, trial)
dict_cfg = hydra.compose(config_name='config', overrides=overrides)
Here, "your_hydra_config" is the name of the configuration file that
includes the configuration parameters to override.
Args:
tune_cfg: a structure specifying how to sample new parameter values.
trial: the optuna trial related to the sampled parameters.
use_full_name: use the fully qualified setting name. Default to a
human-readable name.
Returns:
A list of strings for hydra configuration overrides.
Raises:
OptunaError: if the suggested configuration is invalid.
"""
all_overrides: list[str] = [*tune_cfg.overrides]
for setting_name, param_value in tune_cfg.tune.params.items():
if use_full_name:
param_name = setting_name
else:
*prefix_parts, param_name = setting_name.rsplit('.', maxsplit=2)
if param_name.isdigit() and prefix_parts:
param_name = f'{prefix_parts[-1]} {param_name}'
param_name = string.capwords(param_name.replace('_', ' '))
if param_value.suggest == 'suggest_list':
new_value = []
for i in range(
trial.suggest_int(
name=f'{param_name} #',
low=param_value.settings.min_length,
high=param_value.settings.max_length,
)
):
try:
bound_suggest = getattr(trial, param_value.settings.suggest)
except AttributeError as ae:
msg = f'Invalid suggest configuration: {ae}.'
raise OptunaError(msg) from ae
new_value.append(
bound_suggest(
f'{param_name} {i}', **param_value.settings.settings
)
)
else:
try:
bound_suggest = getattr(trial, param_value.suggest)
except AttributeError as ae:
msg = f'Invalid suggest configuration: {ae}.'
raise OptunaError(msg) from ae
new_value = bound_suggest(param_name, **param_value.settings)
all_overrides.append(f'{setting_name}={new_value}')
return all_overrides
[docs]
def get_final_value(
trial: optuna.Trial,
filter_fn: Callable[[Sequence[float]], float] | None = None,
) -> float:
"""Calculates a trial's final value from its intermediate reported values.
This function aggregates the intermediate values reported during trial
optimization using trial.report().
Important: This function will not work with trials created using study.ask()
as these don't populate the intermediate values in the corresponding
FrozenTrial.
Args:
trial: the completed Optuna trial to evaluate.
filter_fn: function to aggregate the trial's intermediate values.
Defaults to min or max depending on the study direction.
Returns:
The aggregated final value for the trial.
Raises:
OptunaError: if the trial has no reported values, or if there's
a trial number mismatch.
"""
current_study = trial.study
if filter_fn is None:
filter_fn = min if current_study.direction.name == 'MINIMIZE' else max
for frozen_trial in reversed(current_study.trials):
if frozen_trial.number == trial.number:
break
else:
raise OptunaError('trial number mismatch.')
reported_values = list(frozen_trial.intermediate_values.values())
if not reported_values:
msg = 'trial has no reported values. Did you use study.optimize?'
raise OptunaError(msg)
return filter_fn(reported_values)