Types and Protocols

Core Types

Throughout drytorch, generic variables must satisfy these constraints:

  • Input and Target: torch.Tensor | MutableSequence[torch.Tensor] | NamedTuple

  • Output: no constraints

  • Data: 2-tuple where both elements follow Input/Target constraints

Note: The notation for the generic variables has been simplified to ignore subtype relationships.
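As a rough illustration of the constraints above, the following sketch expresses them as runtime checks. The function names `is_input_or_target` and `is_data`, and the `Batch` tuple, are hypothetical and not part of drytorch. A stand-in `Tensor` class replaces `torch.Tensor` so the example runs without PyTorch, and the contents of a named tuple are deliberately not inspected because the constraint above does not specify them.

```python
from collections.abc import MutableSequence
from typing import Any, NamedTuple


class Tensor:
    """Stand-in for torch.Tensor (assumption: lets the sketch run without PyTorch)."""


def is_input_or_target(value: Any) -> bool:
    """Hypothetical check mirroring the Input/Target constraint."""
    if isinstance(value, Tensor):
        return True
    # A NamedTuple instance is a tuple that carries a `_fields` attribute.
    if isinstance(value, tuple) and hasattr(value, "_fields"):
        return True
    # A mutable sequence (for example a list) qualifies if it holds only tensors.
    if isinstance(value, MutableSequence):
        return all(isinstance(item, Tensor) for item in value)
    return False


def is_data(value: Any) -> bool:
    """Hypothetical check for Data: a 2-tuple of valid Input/Target values."""
    return (
        isinstance(value, tuple)
        and len(value) == 2
        and all(is_input_or_target(item) for item in value)
    )


class Batch(NamedTuple):
    """Illustrative named tuple grouping two tensors."""

    image: Tensor
    mask: Tensor


# A Data sample: a named tuple as the input and a list of tensors as the target.
sample = (Batch(Tensor(), Tensor()), [Tensor(), Tensor()])
```

Under these assumptions, `is_data(sample)` holds, while a plain (immutable) tuple of tensors is rejected as an Input or Target because it is neither a mutable sequence nor a named tuple.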

Diagram

The following diagram maps the dependencies between the core interfaces using UML-style notation:

  • Refinement (<|--): Indicates that one protocol extends or refines another (e.g., an inheritance relationship).

  • Structural Association (-->): Represents a structural requirement, often implemented using Dependency Injection.

  • Dependency (..>): Represents a logical dependency that is not enforced by the protocol but is often necessary for its implementation.
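The three relationship kinds can be sketched with `typing.Protocol`. All names below (`ObjectiveLike`, `LossLike`, `RunningError`, `AbsoluteLoss`, `Evaluator`) are hypothetical simplifications that use plain floats instead of tensors; they are not drytorch's actual definitions.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Protocol, runtime_checkable


@runtime_checkable
class ObjectiveLike(Protocol):
    """Hypothetical, simplified stand-in for an objective interface."""

    def update(self, outputs: float, targets: float) -> None: ...
    def compute(self) -> float | None: ...
    def reset(self) -> None: ...


@runtime_checkable
class LossLike(ObjectiveLike, Protocol):
    """Refinement (<|--): adds forward() on top of ObjectiveLike."""

    def forward(self, outputs: float, targets: float) -> float: ...


class RunningError:
    """Satisfies ObjectiveLike structurally, without inheriting from it."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, outputs: float, targets: float) -> None:
        self.total += abs(outputs - targets)
        self.count += 1

    def compute(self) -> float | None:
        return self.total / self.count if self.count else None

    def reset(self) -> None:
        self.total, self.count = 0.0, 0


class AbsoluteLoss(RunningError):
    """Also satisfies LossLike because it provides forward()."""

    def forward(self, outputs: float, targets: float) -> float:
        self.update(outputs, targets)
        return abs(outputs - targets)


@dataclass
class Evaluator:
    """Structural association (-->): the objective is a required attribute,
    supplied from outside (dependency injection)."""

    objective: ObjectiveLike

    def score(self, pairs: Iterable[tuple[float, float]]) -> float | None:
        # Dependency (..>): the pairs would often come from a loader,
        # but nothing in this class requires one.
        self.objective.reset()
        for outputs, targets in pairs:
            self.objective.update(outputs, targets)
        return self.objective.compute()
```

In this sketch, `isinstance(AbsoluteLoss(), LossLike)` is true and `isinstance(RunningError(), LossLike)` is false, since `runtime_checkable` protocols only check that the required methods are present.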

classDiagram
    direction TB
    class LoaderProtocol["LoaderProtocol[(Input, Target)]"] {
	    batch_size : int | None
	    sampler : torch.utils.data.Sampler | Iterable
	    +get_dataset() Data
	    +__iter__() Iterator[(Input, Target)]
	    +__len__() int
    }

    class SchedulerProtocol {
	    +__call__(base_lr, epoch) float
    }

    class GradientOpProtocol {
	    +__call__(params: Iterable[Parameter])
    }

    class ModuleProtocol["ModuleProtocol[Input, Output]"] {
	    +forward(inputs: Input) Output
    }

    class ObjectiveProtocol["ObjectiveProtocol[Output, Target]"] {
	    +update(outputs: Output, targets: Target)
	    +compute() Mapping | Tensor | None
	    +reset()
    }

    class LossProtocol["LossProtocol[Output, Target]"] {
	    +forward(outputs: Output, targets: Target) Tensor
    }

    class LearningProtocol {
	    optimizer_cls : type[Optimizer]
	    base_lr : float | dict
	    scheduler : SchedulerProtocol
	    gradient_op : GradientOpProtocol
    }

    class CheckpointProtocol {
	    +bind_model(model: ModelProtocol[Any, Any])
	    +bind_optimizer(optimizer: Optimizer)
	    +save()
	    +load(epoch: int)
    }

    class ModelProtocol["ModelProtocol[Input, Output]"] {
	    module : ModuleProtocol
	    epoch : int
	    checkpoint : CheckpointProtocol
	    mixed_precision : bool
	    +__call__(inputs: Input) Output
	    +increment_epoch()
    }

    class MonitorProtocol {
	    model : ModelProtocol[Any, Any]
	    +name : str
	    +computed_metrics : Mapping[str, float]
    }

    class TrainerProtocol["TrainerProtocol[Input, Target, Output]"] {
	    model: ModelProtocol[Input, Output]
	    learning_schema: LearningProtocol
	    objective: LossProtocol[Output, Target]
	    validation: MonitorProtocol | None
	    +terminated : bool
	    +train(num_epochs: int)
	    +update_learning_rate(base_lr, scheduler)
    }

    CheckpointProtocol <--> ModelProtocol : saves / binds with
    ModelProtocol --> ModuleProtocol : wraps
    LearningProtocol --> GradientOpProtocol : contains
    LearningProtocol --> SchedulerProtocol : contains
    MonitorProtocol ..> LoaderProtocol : often gets data from
    MonitorProtocol --> ModelProtocol : evaluates
    MonitorProtocol ..> ObjectiveProtocol : typically according to
    TrainerProtocol --> LearningProtocol : follows
    TrainerProtocol ..> LoaderProtocol : often gets data from
    TrainerProtocol --> LossProtocol : optimizes
    TrainerProtocol --> ModelProtocol : trains
    TrainerProtocol --> MonitorProtocol : can be validated by
    MonitorProtocol <|-- TrainerProtocol : refines
    ObjectiveProtocol <|-- LossProtocol : refines
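Because these interfaces are protocols, any object with a matching signature can fill a role. As a minimal sketch for the scheduler signature `__call__(base_lr, epoch) float` shown in the diagram, both a plain function and a callable class qualify. The names `SchedulerLike`, `constant`, `ExponentialDecay` and `learning_rates` are hypothetical, and counting epochs from 0 is an assumption of this example rather than something the diagram states.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol


class SchedulerLike(Protocol):
    """Hypothetical protocol mirroring the scheduler signature in the diagram."""

    def __call__(self, base_lr: float, epoch: int) -> float: ...


def constant(base_lr: float, epoch: int) -> float:
    """A plain function matches the protocol: it ignores the epoch."""
    return base_lr


@dataclass(frozen=True)
class ExponentialDecay:
    """A callable class matches it too: multiply by `factor` once per epoch."""

    factor: float = 0.5

    def __call__(self, base_lr: float, epoch: int) -> float:
        return base_lr * self.factor**epoch


def learning_rates(
    scheduler: SchedulerLike, base_lr: float, num_epochs: int
) -> list[float]:
    """Evaluate the scheduler for epochs 0 .. num_epochs - 1 (assumed indexing)."""
    return [scheduler(base_lr, epoch) for epoch in range(num_epochs)]


rates = learning_rates(ExponentialDecay(factor=0.5), base_lr=1.0, num_epochs=3)
```

Neither `constant` nor `ExponentialDecay` inherits from `SchedulerLike`; a static type checker accepts both purely on the shape of their call signature.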