"""Module containing a tqdm tracker for progress bars."""
from __future__ import annotations
import functools
import sys
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, ClassVar
import tqdm.auto as tqdm
from typing_extensions import override
from drytorch.core import log_events, tracking
__all__ = [
'EpochBar',
'TqdmLogger',
'TrainingBar',
]
if TYPE_CHECKING:
from _typeshed import SupportsWrite
[docs]
class EpochBar:
"""Bar that displays the current epoch's metrics and progress.
This class is also used to display metrics and progress during evaluation.
Attributes:
fmt: the formatting of the bar.
seen_str: the name for the elements of the batches.
color: the color of the bar.
pbar: the wrapped tqdm bar.
"""
fmt: ClassVar[str] = (
'{l_bar}{bar}| {n_fmt}/{total_fmt}, {elapsed}<{remaining}{postfix}'
)
seen_str: ClassVar[str] = 'Samples'
color: ClassVar[str] = 'green'
pbar: tqdm.tqdm
_batch_size: int | None
_n_samples: int
_n_iter: int
_epoch_seen: int
def __init__(
self,
batch_size: int | None,
n_iter: int,
n_samples: int,
leave: bool,
file: SupportsWrite[str],
desc: str,
position: int = 0,
) -> None:
"""Initialize.
Args:
batch_size: how many samples are in one batch.
n_iter: the number of expected iterations.
n_samples: the total number of samples.
leave: whether to leave the bar in after the epoch.
file: the stream where to flush the bar.
desc: description to contextualize the bar.
position: position of the bar in the progress bar group.
"""
self._batch_size = batch_size
self._n_samples = n_samples
self._n_iter = n_iter
is_tqdm_notebook = tqdm.tqdm.__mro__[1].__module__.endswith('notebook')
self.pbar = tqdm.tqdm(
total=n_iter,
leave=leave,
file=file,
desc=desc,
bar_format=self.fmt,
colour=self.color,
position=position,
disable=not leave and is_tqdm_notebook,
)
self._epoch_seen = 0
return
[docs]
def update(self, metrics: Mapping[str, Any], n_processes: int) -> None:
"""Update the bar and displays last batch metrics values.
Args:
metrics: the values from the last batch by metric name.
n_processes: the number of processes used for data loading.
"""
monitor_seen: dict[str, int | str]
last_epoch = self.pbar.n >= self._n_iter - n_processes
if self._batch_size is not None:
self._epoch_seen += self._batch_size * n_processes
if last_epoch:
self._epoch_seen = min(self._epoch_seen, self._n_samples)
monitor_seen = {self.seen_str: self._epoch_seen}
else:
monitor_seen = {self.seen_str: '?'}
monitor_metric = {
metric_name: f'{metric_value:.3e}'
for metric_name, metric_value in metrics.items()
}
monitor_dict = monitor_seen | monitor_metric
self.pbar.set_postfix(monitor_dict, refresh=False)
self.pbar.update(n_processes)
if last_epoch:
self.pbar.close()
return
[docs]
class TrainingBar:
"""Create a bar for the training progress.
Class Attributes:
fmt: the formatting of the bar.
desc: the name for the iteration.
color: the color of the bar.
Attributes:
pbar: the wrapped tqdm bar.
"""
fmt: ClassVar[str] = (
'{l_bar}{bar}| {n_fmt}/{total_fmt}, {elapsed}<{remaining}'
)
desc: ClassVar[str] = 'Epoch'
color: ClassVar[str] = 'blue'
pbar: tqdm.tqdm
_start_epoch: int
_end_epoch: int
def __init__(
self,
start_epoch: int,
end_epoch: int,
file: SupportsWrite[str],
leave: bool,
) -> None:
"""Initialize.
Args:
start_epoch: the epoch from which the bar should start.
end_epoch: the epoch where the bar should end.
file: the stream where to flush the bar.
leave: If True, leave the bar once the iterations have completed.
"""
self.pbar = tqdm.trange(
start_epoch,
end_epoch,
desc=f'{self.desc}:',
leave=leave,
file=file,
bar_format=self.fmt,
colour=self.color,
)
self._start_epoch = start_epoch
self._end_epoch = end_epoch
[docs]
def update(self, current_epoch: int) -> None:
"""Update the bar and display the current epoch.
Args:
current_epoch: the current epoch.
"""
self.pbar.update()
description = f'{self.desc}: {current_epoch} / {self._end_epoch}'
self.pbar.set_description(description)
return
[docs]
class TqdmLogger(tracking.Tracker):
"""Create an epoch progress bar."""
_leave: bool
_file: SupportsWrite[str]
_enable_training_bar: bool
_training_bar: TrainingBar | None
_epoch_bar: EpochBar | None
def __init__(
self,
leave: bool = True,
enable_training_bar: bool = False,
file: SupportsWrite[str] = sys.stderr,
) -> None:
"""Initialize.
Args:
leave: whether to leave the epoch bar after completion.
enable_training_bar: create a bar for the overall training progress.
file: the stream where to flush the bar.
Note:
Enable the training bar only if two progress bars are supported,
and there is no other logger or printer streaming.
"""
super().__init__()
self._leave = leave
self._file = file
self._enable_training_bar = enable_training_bar
self._training_bar = None
self._epoch_bar = None
return
[docs]
@override
def clean_up(self) -> None:
self._clean_epoch_bar()
self._clean_training_bar()
return
[docs]
@functools.singledispatchmethod
@override
def notify(self, event: log_events.Event) -> None:
return super().notify(event)
@notify.register
def _(self, event: log_events.IterateBatchEvent) -> None:
desc = event.source_name.rjust(15)
leave = self._leave and self._training_bar is None
self._epoch_bar = EpochBar(
event.batch_size,
event.n_iter,
event.dataset_size,
leave=leave,
file=self._file,
desc=desc,
position=0 if self._training_bar is None else 1,
)
event.push_updates.append(self._epoch_bar.update)
return super().notify(event)
@notify.register
def _(self, event: log_events.StartTrainingEvent) -> None:
if self._enable_training_bar:
self._training_bar = TrainingBar(
event.start_epoch,
event.end_epoch,
file=self._file,
leave=self._leave,
)
return super().notify(event)
@notify.register
def _(self, event: log_events.StartEpochEvent) -> None:
if self._training_bar is not None:
self._training_bar.update(event.epoch)
return super().notify(event)
@notify.register
def _(self, event: log_events.EndEpochEvent) -> None:
self._clean_epoch_bar()
return super().notify(event)
@notify.register
def _(self, event: log_events.EndTestEvent) -> None:
self._clean_epoch_bar()
return super().notify(event)
@notify.register
def _(self, event: log_events.TerminatedTrainingEvent) -> None:
self._clean_training_bar()
return super().notify(event)
@notify.register
def _(self, event: log_events.EndTrainingEvent) -> None:
self._clean_training_bar()
return super().notify(event)
def _clean_training_bar(self) -> None:
if self._training_bar is not None:
if not self._training_bar.pbar.disable:
self._training_bar.pbar.close()
self._training_bar = None
return
def _clean_epoch_bar(self) -> None:
if self._epoch_bar is not None:
if not self._epoch_bar.pbar.disable:
self._epoch_bar.pbar.close()
self._epoch_bar = None
return