Source code for drytorch.contrib.optuna

"""Support for optuna."""

import string

from collections.abc import Callable, Sequence
from typing import Any, Final, Generic, Literal, TypeVar

import optuna

from omegaconf import DictConfig

from drytorch.core import exceptions
from drytorch.core import protocols as p
from drytorch.lib import hooks


__all__ = [
    'OptunaError',
    'TrialCallback',
    'get_final_value',
    'suggest_overrides',
]

_Target_contra = TypeVar(
    '_Target_contra', bound=p.TargetType, contravariant=True
)
_Output_contra = TypeVar(
    '_Output_contra', bound=p.OutputType, contravariant=True
)


[docs] class OptunaError(exceptions.DryTorchError): """Base class for Optuna errors.""" _template = 'Optuna: {}'
[docs] class TrialCallback(Generic[_Output_contra, _Target_contra]): """Implements pruning logic for training models. Attributes: monitor: Monitor instance trial: Optuna trial. reported: Dictionary mapping epochs to reported values. """ monitor: hooks.MetricMonitor trial: optuna.Trial reported: dict[int, float] def __init__( self, trial: optuna.Trial, filter_fn: Callable[[Sequence[float]], float] = hooks.get_last, metric: p.ObjectiveProtocol[_Output_contra, _Target_contra] | str | None = None, monitor: p.MonitorProtocol | None = None, min_delta: float = 1e-8, best_is: Literal['auto', 'higher', 'lower'] = 'auto', ) -> None: """Initialize. Args: trial: Optuna trial filter_fn: function to aggregate recent metric values. metric: Name of metric to monitor or metric calculator instance. Defaults to the first metric found. monitor: Evaluation protocol to monitor. Defaults to validation if available, trainer instance otherwise. min_delta: Minimum change required to qualify as an improvement. best_is: Whether higher or lower metric values are better. Default 'auto' will determine this from the first measurements. """ self.monitor: Final = hooks.MetricMonitor( metric=metric, monitor=monitor, min_delta=min_delta, best_is=best_is, filter_fn=filter_fn, ) self.trial: Final = trial self.reported: Final = dict[int, float]() return
[docs] def __call__( self, instance: p.TrainerProtocol[Any, _Target_contra, _Output_contra], ) -> None: """Evaluate whether training should be stopped early. Args: instance: Trainer instance to evaluate. Raises: optuna.TrialPruned: if the trial should be pruned. """ self.monitor.record_metric_value(instance) epoch = instance.model.epoch value = self.monitor.filtered_value self.trial.report(value, epoch) self.reported[epoch] = value if self.trial.should_prune(): metric_name = self.monitor.metric_name msg = f'Optuna pruning while monitoring {metric_name}' instance.terminate_training(msg) raise optuna.TrialPruned() return
[docs] def suggest_overrides( tune_cfg: DictConfig, trial: optuna.Trial, use_full_name: bool = False ) -> list[str]: """Suggest values for a trial from structured configurations. This function helps integrate optuna into hydra by specifying trial parameters present in the hydra run configuration. The configuration file (loadable with hydra) should follow this structure: .. code-block:: yaml tune: params: param_name: suggest: "suggest_float" # or other optuna suggest method settings: low: 0.0 high: 1.0 list_param: suggest: "suggest_list" settings: min_length: 1 max_length: 5 suggest: "suggest_float" # method for sampling list elements settings: low: 0.0 high: 1.0 overrides: [] # additional static overrides For 'suggest_list' configurations, the settings must specify: - min_length and max_length: bounds for the size of the list. - nested suggest and settings: used to sample each list element. The resulting values can be used with hydra.initialize and hydra.compose. Example usage: .. code-block:: python import hydra with hydra.initialize(version_base=None, config_path='path/to/config'): overrides = suggest_overrides(tune_cfg, trial) dict_cfg = hydra.compose(config_name='config', overrides=overrides) Here, "your_hydra_config" is the name of the configuration file that includes the configuration parameters to override. Args: tune_cfg: a structure specifying how to sample new parameter values. trial: the optuna trial related to the sampled parameters. use_full_name: use the fully qualified setting name. Default to a human-readable name. Returns: A list of strings for hydra configuration overrides. Raises: OptunaError: if the suggested configuration is invalid. """ all_overrides: list[str] = [*tune_cfg.overrides] for setting_name, param_value in tune_cfg.tune.params.items(): if use_full_name: param_name = setting_name else: *prefix_parts, param_name = setting_name.rsplit('.', maxsplit=2) if param_name.isdigit() and prefix_parts: param_name = f'{prefix_parts[-1]} {param_name}' param_name = string.capwords(param_name.replace('_', ' ')) if param_value.suggest == 'suggest_list': new_value = [] for i in range( trial.suggest_int( name=f'{param_name} #', low=param_value.settings.min_length, high=param_value.settings.max_length, ) ): try: bound_suggest = getattr(trial, param_value.settings.suggest) except AttributeError as ae: msg = f'Invalid suggest configuration: {ae}.' raise OptunaError(msg) from ae new_value.append( bound_suggest( f'{param_name} {i}', **param_value.settings.settings ) ) else: try: bound_suggest = getattr(trial, param_value.suggest) except AttributeError as ae: msg = f'Invalid suggest configuration: {ae}.' raise OptunaError(msg) from ae new_value = bound_suggest(param_name, **param_value.settings) all_overrides.append(f'{setting_name}={new_value}') return all_overrides
[docs] def get_final_value( trial: optuna.Trial, filter_fn: Callable[[Sequence[float]], float] | None = None, ) -> float: """Calculates a trial's final value from its intermediate reported values. This function aggregates the intermediate values reported during trial optimization using trial.report(). Important: This function will not work with trials created using study.ask() as these don't populate the intermediate values in the corresponding FrozenTrial. Args: trial: the completed Optuna trial to evaluate. filter_fn: function to aggregate the trial's intermediate values. Defaults to min or max depending on the study direction. Returns: The aggregated final value for the trial. Raises: OptunaError: if the trial has no reported values, or if there's a trial number mismatch. """ current_study = trial.study if filter_fn is None: filter_fn = min if current_study.direction.name == 'MINIMIZE' else max for frozen_trial in reversed(current_study.trials): if frozen_trial.number == trial.number: break else: raise OptunaError('trial number mismatch.') reported_values = list(frozen_trial.intermediate_values.values()) if not reported_values: msg = 'trial has no reported values. Did you use study.optimize?' raise OptunaError(msg) return filter_fn(reported_values)