drytorch.contrib.optuna
Support for Optuna.
Functions
get_final_value | Calculates a trial's final value from its intermediate reported values.
suggest_overrides | Suggest values for a trial from structured configurations.
Classes
TrialCallback | Implements pruning logic for training models.
Exceptions
OptunaError | Base class for Optuna errors.
- exception OptunaError(*args: Any)
Bases: DryTorchError
Base class for Optuna errors.
Initialize.
- Parameters:
*args (Any) – arguments to be formatted into the message template.
- Return type:
None
- class TrialCallback(trial: Trial, filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1), metric: ObjectiveProtocol[_Output_contra, _Target_contra] | str | None = None, monitor: MonitorProtocol | None = None, min_delta: float = 1e-08, best_is: Literal['auto', 'higher', 'lower'] = 'auto')
Bases: Generic[_Output_contra, _Target_contra]
Implements pruning logic for training models.
- monitor
Monitor instance.
- trial
Optuna trial.
Initialize.
- Parameters:
trial (Trial) – Optuna trial.
filter_fn (Callable[[Sequence[float]], float]) – Function that aggregates the recent metric values into a single value.
metric (ObjectiveProtocol[_Output_contra, _Target_contra] | str | None) – Name of the metric to monitor, or a metric calculator instance. Defaults to the first metric found.
monitor (MonitorProtocol | None) – Evaluation protocol to monitor. Defaults to the validation protocol if available, otherwise the trainer instance.
min_delta (float) – Minimum change required to qualify as an improvement.
best_is (Literal['auto', 'higher', 'lower']) – Whether higher or lower metric values are better. The default, ‘auto’, infers this from the first measurements.
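filter_fn accepts any callable that maps a sequence of recent metric values to a single float; the default, operator.itemgetter(-1), simply returns the most recent value. The sketch below contrasts the default with a smoothing alternative. The helper mean_of_last is hypothetical (not part of drytorch), and how TrialCallback collects the sequence it passes to filter_fn is not shown here.

```python
import operator
import statistics
from collections.abc import Callable, Sequence

# Default used by TrialCallback: keep only the most recent value.
last_value: Callable[[Sequence[float]], float] = operator.itemgetter(-1)


def mean_of_last(n: int) -> Callable[[Sequence[float]], float]:
    """Hypothetical filter factory: average of the last n values."""

    def _filter(values: Sequence[float]) -> float:
        # Slicing past the start is safe: fewer than n values are all averaged.
        return statistics.fmean(values[-n:])

    return _filter


history = [1.0, 0.75, 0.5, 0.25]  # e.g. validation loss per epoch
print(last_value(history))       # 0.25
print(mean_of_last(2)(history))  # 0.375
print(min(history))              # 0.25 (built-ins such as min also fit)
```

Under these assumptions, a smoothed filter could be supplied as TrialCallback(trial, filter_fn=mean_of_last(3)), so that a single noisy epoch has less influence on the value used for pruning.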
- __call__(instance: TrainerProtocol[Any, _Target_contra, _Output_contra]) → None
Evaluate whether training should be stopped early.
- Parameters:
instance (TrainerProtocol[Any, _Target_contra, _Output_contra]) – Trainer instance to evaluate.
- Raises:
optuna.TrialPruned – if the trial should be pruned.
- Return type:
None
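Pruning callbacks built on Optuna commonly follow a report-then-ask pattern: report the current value with trial.report(value, step), then query trial.should_prune() and raise optuna.TrialPruned if it returns True. The sketch below assumes __call__ works along these lines; it is not drytorch's implementation. To run without Optuna installed, it uses a local TrialPruned stand-in and a toy ThresholdTrial in place of a real trial and pruner (both hypothetical).

```python
import operator
from collections.abc import Callable, Sequence
from typing import Protocol


class TrialLike(Protocol):
    """The two optuna.trial.Trial methods this sketch relies on."""

    def report(self, value: float, step: int) -> None: ...

    def should_prune(self) -> bool: ...


class TrialPruned(Exception):
    """Local stand-in for optuna.TrialPruned so the sketch runs without Optuna."""


def prune_check(
    trial: TrialLike,
    history: Sequence[float],
    step: int,
    filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1),
) -> None:
    """Report the aggregated metric, then raise if the pruner asks to stop."""
    value = filter_fn(history)
    trial.report(value, step)
    if trial.should_prune():
        raise TrialPruned(f"Pruned at step {step} (value={value}).")


class ThresholdTrial:
    """Toy trial whose 'pruner' stops once the latest value exceeds a threshold."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold
        self.intermediate_values: dict[int, float] = {}

    def report(self, value: float, step: int) -> None:
        self.intermediate_values[step] = value

    def should_prune(self) -> bool:
        latest_step = max(self.intermediate_values)
        return self.intermediate_values[latest_step] > self.threshold


trial = ThresholdTrial(threshold=1.0)
prune_check(trial, [0.5], step=1)  # 0.5 <= 1.0: training continues
try:
    prune_check(trial, [0.5, 2.0], step=2)  # 2.0 > 1.0: pruned
except TrialPruned as err:
    print(err)  # Pruned at step 2 (value=2.0).
```

With a real Optuna trial, the decision in should_prune() comes from the study's pruner (for example optuna.pruners.MedianPruner), which compares the reported values against those of other trials.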
- get_final_value(trial: Trial, filter_fn: Callable[[Sequence[float]], float] | None = None) → float
Calculates a trial’s final value from its intermediate reported values.
This function aggregates the intermediate values reported with trial.report() during optimization.
Important: This function will not work with trials created using study.ask(), as these do not populate the intermediate values in the corresponding FrozenTrial.
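A minimal sketch of the aggregation step, assuming the intermediate values are available as a step-to-value mapping (the shape of FrozenTrial.intermediate_values in Optuna). The function final_value_from and the exception FinalValueError are hypothetical stand-ins for get_final_value and OptunaError. The itemgetter(-1) default mirrors TrialCallback; it is an assumption, since what get_final_value does when filter_fn is None is not described here.

```python
import operator
from collections.abc import Callable, Mapping, Sequence


class FinalValueError(Exception):
    """Hypothetical stand-in for OptunaError in this sketch."""


def final_value_from(
    intermediate_values: Mapping[int, float],
    filter_fn: Callable[[Sequence[float]], float] = operator.itemgetter(-1),
) -> float:
    """Aggregate step -> value reports into one number (not drytorch's code)."""
    if not intermediate_values:
        raise FinalValueError("The trial has no reported values.")
    # Order by step so that "last" means the report with the highest step.
    ordered = [intermediate_values[step] for step in sorted(intermediate_values)]
    return filter_fn(ordered)


reported = {0: 0.75, 2: 0.5, 1: 0.25}  # insertion order differs from step order
print(final_value_from(reported))       # 0.5 (value at step 2)
print(final_value_from(reported, min))  # 0.25
print(final_value_from(reported, max))  # 0.75
```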
- Parameters:
- Returns:
The aggregated final value for the trial.
- Raises:
OptunaError – if the trial has no reported values, or if there’s a trial number mismatch.
- Return type:
float
- suggest_overrides(tune_cfg: DictConfig, trial: Trial, use_full_name: bool = False) → list[str]
Suggest values for a trial from structured configurations.
This function integrates Optuna with Hydra by suggesting values for the trial parameters declared in the Hydra run configuration.
The configuration file (loadable with hydra) should follow this structure:
    tune:
      params:
        param_name:
          suggest: "suggest_float"  # or other optuna suggest method
          settings:
            low: 0.0
            high: 1.0
        list_param:
          suggest: "suggest_list"
          settings:
            min_length: 1
            max_length: 5
            suggest: "suggest_float"  # method for sampling list elements
            settings:
              low: 0.0
              high: 1.0
      overrides: []  # additional static overrides
For ‘suggest_list’ configurations, the settings must specify:
- min_length and max_length: bounds on the length of the list.
- nested suggest and settings: the method and settings used to sample each list element.
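The sketch below shows one way a configuration with this structure could be turned into Hydra override strings: each suggest name is dispatched to the trial method of the same name, and suggest_list first draws a length and then samples each element. Everything here is an assumption for illustration, not the behavior of suggest_overrides: the helper functions, the "{name}_length" and "{name}_{i}" parameter names, the use of plain dicts for the mapping under the tune key, and StubTrial (a deterministic stand-in for optuna.trial.Trial that always returns the lower bound). The sketch also ignores use_full_name.

```python
from collections.abc import Mapping
from typing import Any


class StubTrial:
    """Deterministic stand-in for optuna.trial.Trial: always picks the lower bound."""

    def suggest_float(self, name: str, low: float, high: float) -> float:
        return low

    def suggest_int(self, name: str, low: int, high: int) -> int:
        return low


def suggest_value(trial: Any, name: str, spec: Mapping[str, Any]) -> Any:
    """Sample one parameter from a {"suggest": ..., "settings": {...}} spec."""
    settings = dict(spec["settings"])  # shallow copy so pop() leaves spec intact
    if spec["suggest"] == "suggest_list":
        # Hypothetical naming: one parameter for the length, one per element.
        length = trial.suggest_int(
            f"{name}_length", settings.pop("min_length"), settings.pop("max_length")
        )
        element_spec = {"suggest": settings["suggest"], "settings": settings["settings"]}
        return [suggest_value(trial, f"{name}_{i}", element_spec) for i in range(length)]
    # Any other name is looked up on the trial, e.g. trial.suggest_float(name, low=..., high=...).
    method = getattr(trial, spec["suggest"])
    return method(name, **settings)


def to_override(name: str, value: Any) -> str:
    """Format a Hydra override such as "lr=0.1" or "widths=[1,2]"."""
    if isinstance(value, list):
        return f"{name}=[{','.join(str(v) for v in value)}]"
    return f"{name}={value}"


def overrides_from(tune: Mapping[str, Any], trial: Any) -> list[str]:
    """Build overrides from the mapping under the `tune` key, plus static ones."""
    suggested = [
        to_override(name, suggest_value(trial, name, spec))
        for name, spec in tune["params"].items()
    ]
    return suggested + list(tune.get("overrides", []))


tune = {
    "params": {
        "param_name": {"suggest": "suggest_float", "settings": {"low": 0.0, "high": 1.0}},
        "list_param": {
            "suggest": "suggest_list",
            "settings": {
                "min_length": 2,
                "max_length": 5,
                "suggest": "suggest_float",
                "settings": {"low": 0.5, "high": 1.0},
            },
        },
    },
    "overrides": ["seed=0"],  # hypothetical static override
}
print(overrides_from(tune, StubTrial()))
# ['param_name=0.0', 'list_param=[0.5,0.5]', 'seed=0']
```

Dispatching by method name keeps the configuration open to any suggest_* method the trial provides, at the cost of surfacing typos in the configuration only at sampling time.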
The resulting overrides can be passed to hydra.compose inside a hydra.initialize context. Example usage:
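    import hydra

    from drytorch.contrib.optuna import suggest_overrides

    # tune_cfg: the tuning configuration described above; trial: an optuna Trial.
    with hydra.initialize(version_base=None, config_path='path/to/config'):
        overrides = suggest_overrides(tune_cfg, trial)
        dict_cfg = hydra.compose(config_name='config', overrides=overrides)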
Here, “config” is the name of the configuration file (located in config_path) that contains the parameters to override.
- Parameters:
- Returns:
A list of strings for hydra configuration overrides.
- Raises:
OptunaError – if the suggested configuration is invalid.
- Return type:
list[str]