Source code for drytorch.trackers.plotly

"""Module containing a plotter class using plotly."""

import plotly.graph_objs as go

from typing_extensions import override

from drytorch.trackers import base_classes


# Names exported by ``from drytorch.trackers.plotly import *``.
__all__ = ['PlotlyPlotter']


class PlotlyPlotter(base_classes.BasePlotter[go.Figure]):
    """Plotter that builds a brand-new plotly figure on every call.

    Existing figures are never updated; each call to ``_plot_metric``
    produces, displays and returns a fresh ``go.Figure``.
    """

    @override
    def _plot_metric(
        self,
        model_name: str,
        metric_name: str,
        **sourced_array: base_classes.NpArray,
    ) -> go.Figure:
        """Draw one metric of one model, with one trace per source.

        Args:
            model_name: name of the model, used as the figure title.
            metric_name: name of the metric, used as the y-axis title.
            **sourced_array: arrays keyed by source name; each key becomes
                the trace name. Column 0 goes on the x-axis (titled
                'Epoch') and column 1 on the y-axis. Presumably one row
                per logged point — TODO confirm against BasePlotter.

        Returns:
            The figure, after it has been displayed with ``show()``.
        """
        traces: list[go.Scatter | go.Bar] = []
        for source, values in sourced_array.items():
            extra_options = {}
            if values.shape[0] == 1:
                # A lone point draws no visible line segment, so render it
                # as a large standalone marker instead.
                extra_options = {
                    'mode': 'markers',
                    'marker': go.scatter.Marker(symbol=24, size=20),
                }
            traces.append(
                go.Scatter(
                    x=values[:, 0],
                    y=values[:, 1],
                    name=source,
                    **extra_options,
                )
            )

        layout = go.Layout(
            title=model_name,
            xaxis={'title': 'Epoch'},
            yaxis={'title': metric_name},
        )
        figure = go.Figure(data=traces, layout=layout)
        figure.show()
        return figure