drytorch.contrib.optuna

Support for Optuna.

Functions

get_final_value(trial[, filter_fn])

Calculates a trial's final value from its intermediate reported values.

suggest_overrides(tune_cfg, trial[, ...])

Suggest values for a trial from structured configurations.

Classes

TrialCallback(trial[, filter_fn, metric, ...])

Implements pruning logic for training models.

Exceptions

OptunaError(*args)

Base class for Optuna errors.

exception OptunaError(*args: Any)

Bases: DryTorchError

Base class for Optuna errors.

Initialize.

Parameters:

*args (Any) – arguments to be formatted into the message template.

Return type:

None

class TrialCallback(trial: Trial, filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), metric: ObjectiveProtocol[_Output_contra, _Target_contra] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, best_is: Literal['auto', 'higher', 'lower'] = 'auto')

Bases: Generic[_Output_contra, _Target_contra]

Implements pruning logic for training models.

monitor

Metric monitor instance.

Type:

drytorch.lib.hooks.MetricMonitor

trial

Optuna trial.

Type:

optuna.trial.Trial

reported

Dictionary mapping epochs to reported values.

Type:

dict[int, float]

Initialize.

Parameters:
  • trial (Trial) – the Optuna trial being evaluated.

  • filter_fn (Callable[[Sequence[float]], float]) – function to aggregate recent metric values. Defaults to the most recent value.

  • metric (ObjectiveProtocol[_Output_contra, _Target_contra] | str | None) – name of the metric to monitor, or a metric calculator instance. Defaults to the first metric found.

  • monitor (MonitorProtocol | None) – evaluation protocol to monitor. Defaults to the validation protocol if available, otherwise the trainer instance.

  • min_delta (float) – minimum change required to qualify as an improvement.

  • best_is (Literal['auto', 'higher', 'lower']) – whether higher or lower metric values are better. The default, ‘auto’, infers this from the first measurements.
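The filter_fn parameter receives the sequence of recent values of the monitored metric and must return a single float. The sketch below compares the default, operator.itemgetter(-1), with mean_of_last, a hypothetical helper (not part of drytorch) that averages the last few values to smooth a noisy curve.

```python
import operator
import statistics
from collections.abc import Callable, Sequence

# Default used by TrialCallback: keep only the most recent value.
last_value: Callable[[Sequence[float]], float] = operator.itemgetter(-1)


def mean_of_last(n: int) -> Callable[[Sequence[float]], float]:
    """Hypothetical helper: build a filter_fn that averages the last n values."""

    def _filter(values: Sequence[float]) -> float:
        # Slicing tolerates n > len(values): the whole history is averaged.
        return statistics.fmean(values[-n:])

    return _filter


history = [0.9, 0.7, 0.65, 0.62]  # e.g. validation loss per epoch
print(last_value(history))        # 0.62
print(mean_of_last(3)(history))   # average of 0.7, 0.65 and 0.62
```

Under the signature above, such a function would be passed as TrialCallback(trial, filter_fn=mean_of_last(3)). Averaging makes pruning decisions less sensitive to a single noisy epoch, at the cost of reacting more slowly to a genuinely bad run.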

__call__(instance: TrainerProtocol[Any, _Target_contra, _Output_contra]) → None

Evaluate whether training should be stopped early.

Parameters:

instance (TrainerProtocol[Any, _Target_contra, _Output_contra]) – Trainer instance to evaluate.

Raises:

optuna.TrialPruned – if the trial should be pruned.

Return type:

None
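A pruning callback of this kind typically aggregates the recent metric values with filter_fn, records the result per epoch, reports it to the trial, and raises if the trial should be pruned. The following is a simplified, self-contained sketch assuming that pattern; it is not drytorch's implementation. StubTrial and TrialPruned are hypothetical stand-ins for optuna.trial.Trial (whose report(value, step) and should_prune() methods it mimics) and optuna.TrialPruned, so the example runs without Optuna installed.

```python
import operator
from collections.abc import Callable, Sequence


class TrialPruned(Exception):
    """Hypothetical stand-in for optuna.TrialPruned."""


class StubTrial:
    """Hypothetical stand-in for optuna.trial.Trial: prunes once a value exceeds a threshold."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold
        self.intermediate: dict[int, float] = {}

    def report(self, value: float, step: int) -> None:
        self.intermediate[step] = value

    def should_prune(self) -> bool:
        latest_step = max(self.intermediate)
        return self.intermediate[latest_step] > self.threshold


class SimplePruningCallback:
    """Simplified sketch of the report-then-prune pattern (not drytorch's code)."""

    def __init__(self, trial, filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1)) -> None:
        self.trial = trial
        self.filter_fn = filter_fn
        self.reported: dict[int, float] = {}  # epoch -> reported value

    def __call__(self, epoch: int, metric_history: Sequence[float]) -> None:
        value = self.filter_fn(metric_history)
        self.reported[epoch] = value
        self.trial.report(value, step=epoch)
        if self.trial.should_prune():
            raise TrialPruned(f'pruned at epoch {epoch}')


trial = StubTrial(threshold=1.0)
callback = SimplePruningCallback(trial)
losses = [0.8, 0.9, 1.2, 0.7]
pruned_at = None
for epoch in range(1, len(losses) + 1):
    try:
        callback(epoch, losses[:epoch])
    except TrialPruned:
        pruned_at = epoch  # training would stop here
        break
```

In the sketch, the loss at epoch 3 exceeds the stub's threshold, so the loop stops before epoch 4 is ever evaluated.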

get_final_value(trial: Trial, filter_fn: Callable[[Sequence[float]], float] | None = None) → float

Calculates a trial’s final value from its intermediate reported values.

This function aggregates the intermediate values reported with trial.report() during optimization.

Important: this function does not work with trials created using study.ask(), because these do not populate the intermediate values in the corresponding FrozenTrial.

Parameters:
  • trial (Trial) – the completed Optuna trial to evaluate.

  • filter_fn (Callable[[Sequence[float]], float] | None) – function to aggregate the trial’s intermediate values. Defaults to min or max depending on the study direction.

Returns:

The aggregated final value for the trial.

Raises:

OptunaError – if the trial has no reported values, or if there’s a trial number mismatch.

Return type:

float
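The default aggregation described above can be sketched in plain Python. In Optuna, the per-step values are available as FrozenTrial.intermediate_values (a dict mapping step to value) and the direction as study.direction; the hypothetical aggregate_intermediate helper below takes them as a plain mapping and the strings 'minimize' or 'maximize' instead, and raises ValueError where drytorch raises OptunaError.

```python
import operator
from collections.abc import Callable, Mapping, Sequence
from typing import Optional


def aggregate_intermediate(
    intermediate_values: Mapping[int, float],
    direction: str,
    filter_fn: Optional[Callable[[Sequence[float]], float]] = None,
) -> float:
    """Hypothetical sketch: reduce per-step values to one final value."""
    if not intermediate_values:
        # drytorch raises OptunaError here; ValueError keeps the sketch self-contained.
        raise ValueError('the trial has no reported values')
    # Sort by step so order-sensitive filters such as "last value" behave predictably.
    values = [intermediate_values[step] for step in sorted(intermediate_values)]
    if filter_fn is None:
        # Best value according to the optimization direction.
        filter_fn = min if direction == 'minimize' else max
    return filter_fn(values)


reported = {0: 0.52, 1: 0.31, 2: 0.40}
print(aggregate_intermediate(reported, 'minimize'))                           # 0.31
print(aggregate_intermediate(reported, 'maximize'))                           # 0.52
print(aggregate_intermediate(reported, 'minimize', operator.itemgetter(-1)))  # 0.4
```

Using the best value rather than the last one rewards a trial for its peak performance, which matches early-stopping setups where the best checkpoint is kept.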

suggest_overrides(tune_cfg: DictConfig, trial: Trial, use_full_name: bool = False) → list[str]

Suggest values for a trial from structured configurations.

This function integrates Optuna with Hydra: it samples values for the trial parameters specified in the Hydra run configuration and returns them as overrides.

The configuration file (loadable with Hydra) should follow this structure:

tune:
  params:
    param_name:
      suggest: "suggest_float"  # or other optuna suggest method
      settings:
        low: 0.0
        high: 1.0
    list_param:
      suggest: "suggest_list"
      settings:
        min_length: 1
        max_length: 5
        suggest: "suggest_float"  # method for sampling list elements
        settings:
          low: 0.0
          high: 1.0
overrides: []  # additional static overrides

For ‘suggest_list’ configurations, the settings must specify:
  • min_length and max_length – bounds for the length of the list.
  • nested suggest and settings – the method and settings used to sample each list element.
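To illustrate the structure above, here is a hypothetical sketch of how a `suggest` entry could be dispatched to the trial method of the same name, and how a ‘suggest_list’ entry might expand into one sampled value per element. It is not drytorch's implementation: the naming scheme for the sampled length and elements (`{name}_length`, `{name}_{i}`) is an assumption, and StubTrial is a deterministic stand-in for optuna.trial.Trial so the example runs without Optuna. Optuna's real suggest_float and suggest_int accept the same `low` and `high` keyword arguments.

```python
from typing import Any


class StubTrial:
    """Hypothetical deterministic stand-in for optuna.trial.Trial."""

    def suggest_float(self, name: str, low: float, high: float) -> float:
        return (low + high) / 2  # midpoint instead of a sampled value

    def suggest_int(self, name: str, low: int, high: int) -> int:
        return high  # upper bound instead of a sampled value


def sample(trial: Any, name: str, spec: dict) -> Any:
    """Dispatch one parameter spec to the trial method named by 'suggest'."""
    settings = spec['settings']
    if spec['suggest'] == 'suggest_list':
        # Assumed scheme: sample the length first, then each element under its own name.
        length = trial.suggest_int(f'{name}_length', low=settings['min_length'], high=settings['max_length'])
        element = {'suggest': settings['suggest'], 'settings': settings['settings']}
        return [sample(trial, f'{name}_{i}', element) for i in range(length)]
    return getattr(trial, spec['suggest'])(name, **settings)


def to_overrides(trial: Any, params: dict) -> list[str]:
    """Format sampled values as Hydra 'key=value' override strings."""
    overrides = []
    for name, spec in params.items():
        value = sample(trial, name, spec)
        if isinstance(value, list):
            value = '[' + ','.join(str(v) for v in value) + ']'
        overrides.append(f'{name}={value}')
    return overrides


# Mirrors the `tune.params` section of the configuration shown above.
params = {
    'param_name': {'suggest': 'suggest_float', 'settings': {'low': 0.0, 'high': 1.0}},
    'list_param': {
        'suggest': 'suggest_list',
        'settings': {
            'min_length': 1,
            'max_length': 5,
            'suggest': 'suggest_float',
            'settings': {'low': 0.0, 'high': 1.0},
        },
    },
}
overrides = to_overrides(StubTrial(), params)
print(overrides)  # ['param_name=0.5', 'list_param=[0.5,0.5,0.5,0.5,0.5]']
```

Giving each list element its own parameter name lets the sampler treat elements independently, and list values are written in Hydra's `key=[a,b]` override syntax.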

The resulting overrides can be passed to hydra.compose inside a hydra.initialize context. Example usage:

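import hydra

from drytorch.contrib.optuna import suggest_overrides

# tune_cfg (the tuning DictConfig) and trial (an Optuna trial) are assumed to be defined
with hydra.initialize(version_base=None, config_path='path/to/config'):
    overrides = suggest_overrides(tune_cfg, trial)
    dict_cfg = hydra.compose(config_name='config', overrides=overrides)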

Here, “config” is the name of the configuration file that contains the parameters to override.

Parameters:
  • tune_cfg (DictConfig) – a structure specifying how to sample new parameter values.

  • trial (Trial) – the optuna trial related to the sampled parameters.

  • use_full_name (bool) – whether to use the fully qualified setting name. Defaults to False, which uses a human-readable name.

Returns:

A list of strings for hydra configuration overrides.

Raises:

OptunaError – if the suggested configuration is invalid.

Return type:

list[str]