drytorch.trackers.base_classes

Module containing abstract classes for trackers.

Classes

BasePlotter([model_names, source_names, ...])

Abstract class for plotting trajectories from sources.

Dumper([par_dir])

Tracker with a standard folder structure.

MemoryMetrics([metric_loader])

Keep all metrics in memory.

MetricLoader()

Interface for trackers that load metrics.

class Dumper(par_dir: Path | None = None)

Bases: Tracker

Tracker with a standard folder structure.

folder_name

name of the folder containing the output.

Type:

ClassVar[str]

user_par_dir

parent directory for the tracker data.

Type:

pathlib.Path | None

_par_dir

parent directory set by experiment.

Type:

pathlib.Path | None

_exp_name

experiment name set by experiment.

Type:

str | None

_run_id

run identifier set by experiment.

Type:

str | None

Initialize.

Parameters:

par_dir (Path | None) – the parent directory for the tracker data. Defaults to the directory of the current experiment.

property par_dir: Path

Return the parent directory for the experiments.

Raises:

AccessOutsideScopeError – when the default folder is not available.

property run_id: str

Return the identifier for the experiment run.

Raises:

AccessOutsideScopeError – when the id of the run is not available.

property exp_name: str

Return the name of the experiment.

Raises:

AccessOutsideScopeError – when the name is not available.
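
The three properties above share one pattern: a value that the experiment sets later is exposed through a property that raises if it is read too early. The following is a minimal sketch of that pattern, not the library's code. `DumperSketch` and the local `AccessOutsideScopeError` are hypothetical stand-ins, and the precedence of `user_par_dir` over `_par_dir` is an assumption.

```python
from pathlib import Path
from typing import Optional


class AccessOutsideScopeError(Exception):
    """Local stand-in for the library's exception (hypothetical definition)."""


class DumperSketch:
    """Illustrative only: mirrors the documented attributes of Dumper."""

    def __init__(self, par_dir: Optional[Path] = None) -> None:
        self.user_par_dir = par_dir  # chosen by the user, may be None
        self._par_dir = None  # filled in by the experiment
        self._run_id = None  # filled in by the experiment

    @property
    def par_dir(self) -> Path:
        # Assumed precedence: a user-supplied directory wins.
        if self.user_par_dir is not None:
            return self.user_par_dir
        if self._par_dir is None:
            raise AccessOutsideScopeError('default folder not available')
        return self._par_dir

    @property
    def run_id(self) -> str:
        if self._run_id is None:
            raise AccessOutsideScopeError('run id not available')
        return self._run_id
```

Reading `par_dir` on a fresh `DumperSketch()` raises, while `DumperSketch(Path('out')).par_dir` returns the user-supplied path.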

clean_up() → None

Remove experimental data from the tracker.

Return type:

None

notify(event: Event) → None
notify(event: StartExperimentEvent) → None
notify(event: StopExperimentEvent) → None

Notify the tracker of an event.

Parameters:

event (Event) – the event to notify about.

Return type:

None
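
The overloaded `notify` signatures suggest dispatch on the event type. One way to get this behaviour in Python is `functools.singledispatchmethod`; whether the library uses it is an assumption here. The event classes and their fields below are hypothetical stand-ins, not the library's definitions.

```python
import functools
from dataclasses import dataclass


@dataclass
class Event:
    """Stand-in base event (hypothetical)."""


@dataclass
class StartExperimentEvent(Event):
    exp_name: str  # hypothetical field
    run_id: str  # hypothetical field


@dataclass
class StopExperimentEvent(Event):
    exp_name: str  # hypothetical field


class TrackerSketch:
    """Illustrative tracker that reacts only to the events it registers."""

    def __init__(self) -> None:
        self.exp_name = None
        self.run_id = None

    @functools.singledispatchmethod
    def notify(self, event: Event) -> None:
        # Fallback: events without a registered handler are ignored.
        return None

    @notify.register
    def _(self, event: StartExperimentEvent) -> None:
        # Store what the experiment announces at start.
        self.exp_name = event.exp_name
        self.run_id = event.run_id

    @notify.register
    def _(self, event: StopExperimentEvent) -> None:
        # Forget the experiment once it stops.
        self.exp_name = None
        self.run_id = None
```

With this layout, a subclass handles a new event type by registering one more handler, and unrelated events fall through to the no-op base method.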

class MetricLoader

Bases: Tracker, ABC

Interface for trackers that load metrics.

load_metrics(model_name: str, max_epoch: int = -1) → dict[str, tuple[list[int], dict[str, list[float]]]]

Load metrics from the last run of the experiment.

Parameters:
  • model_name (str) – the name of the model.

  • max_epoch (int) – the maximum epoch to load. Defaults to all.

Returns:

The epochs and the named metric values, keyed by source.

Raises:

ValueError – if max_epoch is less than -1.

Return type:

dict[str, tuple[list[int], dict[str, list[float]]]]
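
To make the return type concrete, the sketch below builds a value of that shape and applies the documented `max_epoch` contract (`-1` keeps everything, anything lower raises `ValueError`). `truncate_metrics` is a hypothetical helper written for illustration, and the source and metric names are made up. It is not how the library loads metrics.

```python
Metrics = dict[str, tuple[list[int], dict[str, list[float]]]]


def truncate_metrics(metrics: Metrics, max_epoch: int = -1) -> Metrics:
    """Keep entries up to max_epoch; -1 keeps everything."""
    if max_epoch < -1:
        raise ValueError('max_epoch must be at least -1.')
    if max_epoch == -1:
        return metrics
    out: Metrics = {}
    for source, (epochs, named_values) in metrics.items():
        # Positions of the epochs that survive the cut.
        keep = [i for i, epoch in enumerate(epochs) if epoch <= max_epoch]
        out[source] = (
            [epochs[i] for i in keep],
            {name: [values[i] for i in keep]
             for name, values in named_values.items()},
        )
    return out


# source name -> (epochs, {metric name -> one value per epoch})
example: Metrics = {
    'train': ([1, 2, 3], {'loss': [0.9, 0.7, 0.6]}),
    'validation': ([1, 2, 3], {'loss': [1.0, 0.8, 0.75]}),
}
```

For instance, `truncate_metrics(example, max_epoch=2)['train']` gives `([1, 2], {'loss': [0.9, 0.7]})`.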

class MemoryMetrics(metric_loader: MetricLoader | None = None)

Bases: Tracker

Keep all metrics in memory.

model_dict

all metrics recorded in this session.

Type:

dict[str, dict[str, tuple[list[int], dict[str, list[float]]]]]

Initialize.

Parameters:

metric_loader (MetricLoader | None) – object to load the metrics.

notify(event: Event) → None
notify(event: MetricEvent) → None
notify(event: LoadModelEvent) → None

Notify the tracker of an event.

Parameters:

event (Event) – the event to notify about.

Return type:

None
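
The nested type of `model_dict` reads more easily from the inside out: model name, then source name, then a pair of epochs and per-metric value lists. The sketch below shows one way such a structure could be filled as metrics arrive. The `record` function and its arguments are hypothetical and do not reflect the library's internals.

```python
ModelDict = dict[str, dict[str, tuple[list[int], dict[str, list[float]]]]]


def record(model_dict: ModelDict,
           model_name: str,
           source_name: str,
           epoch: int,
           metrics: dict[str, float]) -> None:
    """Append one epoch of metrics for a given model and source."""
    sources = model_dict.setdefault(model_name, {})
    epochs, named_values = sources.setdefault(source_name, ([], {}))
    epochs.append(epoch)
    for name, value in metrics.items():
        named_values.setdefault(name, []).append(value)


model_dict: ModelDict = {}
record(model_dict, 'net', 'train', 1, {'loss': 0.9})
record(model_dict, 'net', 'train', 2, {'loss': 0.7})
epochs, values = model_dict['net']['train']
```

After the two calls, `epochs` is `[1, 2]` and `values` is `{'loss': [0.9, 0.7]}`.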

class BasePlotter(model_names: Iterable[str] = (), source_names: Iterable[str] = (), metric_names: Iterable[str] = (), start: int = 1, metric_loader: MetricLoader | None = None)

Bases: MemoryMetrics, ABC, Generic[Plot]

Abstract class for plotting trajectories from sources.

_model_names

names of the models to plot.

Type:

collections.abc.Iterable[str]

_source_names

names of the sources to plot.

Type:

collections.abc.Iterable[str]

_metric_names

names of the metrics to plot.

Type:

collections.abc.Iterable[str]

_start

epoch from which to start plotting.

Type:

int

_removed_start

flag indicating if start epochs were removed.

Type:

bool

Initialize.

Parameters:
  • model_names (Iterable[str]) – the names of the models to plot. Defaults to all.

  • source_names (Iterable[str]) – the names of the sources to plot. Defaults to all.

  • metric_names (Iterable[str]) – the names of the metrics to plot. Defaults to all.

  • start (int) – if positive, the epoch from which to start plotting; if negative, the number of most recent epochs to plot. Defaults to all.

  • metric_loader (MetricLoader | None) – a tracker that can load metrics from a previous run.

Note

start_epoch allows you to exclude the initial epochs from the graph. During the first 2 * start_epoch epochs, the graph is shown in its entirety.
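
One possible reading of the start rule, covering both the positive and the negative case, can be sketched as follows. `select_window` is a hypothetical helper, and the exact boundary (strictly fewer than `2 * start` epochs versus at most that many) is a guess, since the note does not specify it.

```python
def select_window(epochs: list[int], start: int = 1) -> list[int]:
    """Return the epochs that would be plotted under the assumed rule."""
    if start < 0:
        # Negative start: keep only the last abs(start) epochs.
        return epochs[start:]
    if not epochs or epochs[-1] < 2 * start:
        # Too early to trim: show the whole curve.
        return epochs
    # Enough epochs recorded: drop everything before start.
    return [epoch for epoch in epochs if epoch >= start]
```

Under this reading, with `start=5` the first seven epochs are all shown, whereas after twelve epochs only epochs 5 to 12 remain. With `start=-3` only the last three epochs are kept.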

notify(event: Event) → None
notify(event: EndEpochEvent) → None
notify(event: EndTestEvent) → None

Notify the tracker of an event.

Parameters:

event (Event) – the event to notify about.

Return type:

None

plot(model_name: str, source_names: Iterable[str] = (), metric_names: Iterable[str] = (), start_epoch: int = 1) → list[Plot]

Plot the learning curves.

Parameters:
  • model_name (str) – the name of the model to plot.

  • source_names (Iterable[str]) – the names of the sources to plot. Defaults to all.

  • metric_names (Iterable[str]) – the names of the metrics to plot. Defaults to all.

  • start_epoch (int) – the epoch from which to start plotting.

Returns:

References to the plot objects or windows depending on the backend.

Return type:

list[Plot]