drytorch.lib.loading
Module containing classes and utilities for batching a dataset.
Functions
- Calculate the number of batches in a dataset.
- Sample a batch of elements from a dataset and transfer them to a device.
- Check if a dataset has a valid length.
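The batch count described in the first entry can be illustrated with a minimal sketch. The helper name `count_batches` and its `drop_last` flag are hypothetical and are not part of this module; the sketch only assumes that a final, smaller batch is counted unless it is explicitly dropped.

```python
def count_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: number of batches needed to cover n_samples."""
    if n_samples < 0:
        raise ValueError('n_samples must be non-negative.')
    if batch_size <= 0:
        raise ValueError('batch_size must be positive.')
    if drop_last:
        # An incomplete final batch is discarded.
        return n_samples // batch_size
    # Ceiling division without floating-point arithmetic.
    return -(-n_samples // batch_size)


if __name__ == '__main__':
    print(count_batches(10, 4))
    print(count_batches(10, 4, drop_last=True))
```

Integer ceiling division avoids rounding surprises that `math.ceil(n / b)` can show for very large integers.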
Classes
- A data-loader class with runtime settings.
- Sliceable pseudo-random permutation.
- Slice a sequence keeping the reference to it.
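To make the idea of a sliceable pseudo-random permutation concrete, here is a minimal sketch in plain Python. The class name `PermutationSketch` and its constructor arguments are assumptions made for illustration; the actual class in this module may differ in name, signature, and behavior.

```python
from __future__ import annotations

import random
from collections.abc import Sequence


class PermutationSketch(Sequence):
    """Hypothetical stand-in: a seeded permutation of range(size)."""

    def __init__(self, size: int, seed: int | None = None) -> None:
        self._indices = list(range(size))
        # A private Random instance leaves the global random state untouched.
        random.Random(seed).shuffle(self._indices)

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, index):
        # An integer returns one index; a slice returns a list of indices.
        return self._indices[index]


if __name__ == '__main__':
    perm = PermutationSketch(10, seed=0)
    print(perm[:3])
    print(len(perm))
```

Because the permutation is seeded, two instances built with the same size and seed yield the same order, which keeps any slices taken from them reproducible.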
- class DataLoader(dataset: Dataset[Data], batch_size: int, pin_memory: bool | None = None, sampler: Sampler | Iterable[Any] | None = None, n_workers: int = 0)
Bases: LoaderProtocol[Data]
A data-loader class with runtime settings.
This class wraps PyTorch’s DataLoader and adds further functionality.
- dataset
the dataset to load data from.
- Type:
torch.utils.data.dataset.Dataset[drytorch.lib.loading.Data]
- sampler
the sampling strategy for the dataset.
Initialize.
- Parameters:
dataset (Dataset[Data]) – the dataset to load data from.
batch_size (int) – number of samples per batch.
pin_memory (bool | None) – whether to pin memory for faster GPU training. When None, defaults to True if hardware acceleration is available.
sampler (Sampler | Iterable[Any] | None) – the strategy used to draw samples from the dataset.
n_workers (int) – number of subprocesses for data loading.
- get_loader(inference: bool) → DataLoader[Data]
Create a DataLoader instance with runtime settings.
- Parameters:
inference (bool) – whether to use inference settings. By default, the value is determined from torch’s global state.
- Returns:
A configured PyTorch DataLoader instance.
- Return type:
DataLoader[Data]
- split(split: float = 0.2, shuffle: bool = True, seed: int = 42) → tuple[DataLoader[Data], DataLoader[Data]]
Split the loader into two.
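- Parameters:
split (float) – must be between 0 and 1. Defaults to 0.2.
shuffle (bool) – defaults to True.
seed (int) – defaults to 42.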
- Returns:
A tuple of (DataLoader, DataLoader).
- Raises:
ValueError – if split is not between 0 and 1.
- Return type:
tuple[DataLoader[Data], DataLoader[Data]]
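The splitting logic can be sketched on plain indices, without any PyTorch objects. The function `split_indices` below is hypothetical and is not the library's implementation. It assumes that the bounds 0 and 1 are inclusive, that the second part receives the `split` fraction, and that `seed` only matters when `shuffle` is true; the source does not confirm any of these details.

```python
import random


def split_indices(
    n_samples: int,
    split: float = 0.2,
    shuffle: bool = True,
    seed: int = 42,
) -> tuple[list[int], list[int]]:
    """Hypothetical sketch: divide range(n_samples) into two disjoint parts."""
    if not 0 <= split <= 1:
        raise ValueError('split must be between 0 and 1.')
    indices = list(range(n_samples))
    if shuffle:
        # A seeded private generator makes the split reproducible.
        random.Random(seed).shuffle(indices)
    cut = int(n_samples * split)
    # Assumption: the second part holds the `split` fraction of the samples.
    return indices[cut:], indices[:cut]


if __name__ == '__main__':
    first, second = split_indices(10)
    print(len(first), len(second))
```

Each index list could then be used to build a subset of the dataset, one per resulting loader.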