PyTorch's data utilities define the concepts of a dataset and a data loader. Both serve the general purpose of loading the training data, so they are easily confused. Why do we need both of them?
Building a deep learning input pipeline involves three challenges:
loading the raw data and transforming it into suitable programming structures
parallelization of the loading routines (to keep up with the training speed)
construction of mini-batch tensors
The first one depends on the raw data format and the loss function; each problem requires a different solution. In PyTorch, it is handled by the Dataset class, which is meant to be specialized by the researcher. The other two challenges are less varied, and they are either difficult or simply tedious to implement. For them, PyTorch provides a general solution in the form of the DataLoader class.
How does each of these concepts help solve its corresponding challenge? Please note that the purpose of this article is exclusively to answer this question; it is not a comprehensive guide to PyTorch datasets and data loaders.
Dataset
For the sake of simplicity, we are going to focus on the most common map-style Dataset class and ignore the IterableDataset class, which follows a slightly different protocol.
The user must provide their own specialization of this class and override its two methods: __len__ and __getitem__. In theory, a Dataset is supposed to mimic a read-only map/dictionary. The documentation states that:
[Dataset] represents a map from (possibly non-integral) indices/keys to data samples
This is a little misleading, because in many cases having non-integral keys would not work as intended. Data utilities often assume that a Dataset specialization accepts integer keys between 0 and its length (exclusive).
This simple example of a Dataset specialization reads all images from a given directory (trivia 1):
from pathlib import Path

from torch.utils.data import Dataset


class ImageDirectoryDataset(Dataset):
    def __init__(self, path: Path):
        # store only the file paths, not the decoded images
        self.image_paths = list(path.iterdir())

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        # decode a single image on demand, e.g. with torchvision.io.read_image
        image_path = self.image_paths[item]
        return load_as_tensor(image_path)
The usage of such a class resembles the usage of a Python list. Unlike a list, though, this class has an acceptable memory footprint: the Dataset stores only the paths to the actual image files, not the image tensors. Given 1 million images, the required memory is reduced from possibly hundreds of gigabytes to just tens of megabytes. The programming interface stays the same, however - dataset[5] returns the ready-to-use tensor at index 5.
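A minimal usage sketch (data/images is a hypothetical directory):

dataset = ImageDirectoryDataset(Path("data/images"))
print(len(dataset))  # the number of image files in the directory
image = dataset[5]   # loads and returns the single image tensor at index 5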
An idiomatic Dataset specialization returns a single example for a given index; the mini-batch creation happens in the DataLoader. Note that __getitem__ can return an arbitrary Python object, but such objects must hold certain properties in order to be used with the DataLoader, as discussed below (DataLoader - Collate section). To keep it simple here, the best idea is to always return a tuple of tensors or primitive types. Importantly, this includes named tuples, which make the code much more readable.
Transforms
A Dataset doesn't need to be passed directly to a DataLoader. It can also be fed to functions that use it to construct a new Dataset. One example is the ConcatDataset class, which accepts a list of Datasets to merge. Another is the random_split function, which allows for quick separation of a validation set out of the training dataset.
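A minimal sketch of both utilities, assuming two hypothetical image directories:

from torch.utils.data import ConcatDataset, random_split

cats = ImageDirectoryDataset(Path("data/cats"))
dogs = ImageDirectoryDataset(Path("data/dogs"))
full = ConcatDataset([cats, dogs])  # acts as one dataset of len(cats) + len(dogs) examples

valid_size = len(full) // 10  # hold out 10% for validation
train_set, valid_set = random_split(full, [len(full) - valid_size, valid_size])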
List as a PyTorch Dataset
The Dataset class has a very simple protocol that requires only the __len__ and __getitem__ methods. Additionally, it overrides the + operator for concatenation. All of these properties are already held by the standard Python list. This is a useful observation when your entire dataset fits in memory - especially when you write a simple test and need only a handful of examples.
Theoretically, there is a risk that some utility would explicitly check for isinstance(x, Dataset) and therefore fail for a list. I find it rather unlikely, but it is an interesting example of why such checks are considered a bad practice. PyTorch utilities do not perform this particular check.
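For instance, a handful of hand-written examples can feed a DataLoader directly; a minimal sketch:

import torch
from torch.utils.data import DataLoader

# a tiny in-memory "dataset": a plain list of (features, label) tuples
data = [
    (torch.tensor([1.0, 2.0]), 0),
    (torch.tensor([3.0, 4.0]), 1),
]
features, labels = next(iter(DataLoader(data, batch_size=2)))
print(features.shape)  # torch.Size([2, 2])
print(labels)          # tensor([0, 1])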
DataLoader
DataLoader is a Swiss Army knife when it comes to data handling in PyTorch. It is a wrapper for datasets that enables efficient iteration over batches of examples. I cover only the two most important features, namely:
mini-batch collating
parallelization
Collate
The most highlighted feature of the DataLoader is the mini-batch construction, i.e. collating. This process is managed by the collate_fn. Its default behavior is intuitive, and you should rely on it whenever possible. Still, it is important to understand the collate logic in order to implement compatible Datasets.
Default Logic
In the default collate logic, tensors and primitives are stacked into a single batch tensor with a new leading batch dimension. For each iterable structure (like a tuple), the logic is applied recursively. That means that you can use an arbitrary structure of nested iterables, as long as there is a tensor or a primitive at the bottom.
Let's consider the following object being returned from the __getitem__ method:
from typing import NamedTuple, Tuple

import torch


class MyExample(NamedTuple):  # `NamedTuple` is important here!
    image: torch.Tensor  # shape = channels x height x width
    label: int
    position: Tuple[int, int]
The collate_fn receives a list of such objects and analyzes their structure recursively:
The overall structure is a tuple, so we must iterate over all 3 entries.
The first entry, image, is a tensor - easy, just call torch.stack on all elements from the list.
The second entry, label, is an integer - equally easy, the solution is the same.
The third entry, position, is a tuple - we must recurse.
Each entry of position is an integer - they are collated separately.
The final structure is not flattened and the original type is preserved: the constructed batch would be of the class MyExample and have the same field names. The only difference is that all integers would now become tensors, and the dimensions of each tensor would include the batch dimension. The image would have a 4D shape instead of a 3D one, and all the integers (previously scalars) would become 1D tensors. This includes the nested tuple, which wouldn't be magically joined (trivia 2).
Note that, in order for this to work, each example returned by the Dataset must have the same structure, e.g. position cannot be an (int, int) tuple for some examples and an (int, int, int) tuple for others. Additionally, since there is a crude torch.stack underneath, all tensors must have the same shape. It is a good idea to defensively resize all images to the same size in __getitem__.
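To see the default logic in action, a minimal sketch with a hypothetical toy dataset returning MyExample instances:

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):  # hypothetical dataset for illustration
    def __len__(self):
        return 8

    def __getitem__(self, item):
        return MyExample(
            image=torch.zeros(3, 32, 32),
            label=item,
            position=(item, item + 1),
        )

batch = next(iter(DataLoader(ToyDataset(), batch_size=4)))
print(type(batch).__name__)  # MyExample - the named tuple type is preserved
print(batch.image.shape)     # torch.Size([4, 3, 32, 32]) - a new batch dimension
print(batch.label)           # tensor([0, 1, 2, 3]) - scalars became a 1D tensor
print(batch.position)        # [tensor([0, 1, 2, 3]), tensor([1, 2, 3, 4])]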
Custom Logic
Sometimes there is no way to design the Dataset in such a way that the default collate_fn would yield appropriate batches. In those cases, you would rather implement your own collate_fn and pass it to the DataLoader's constructor to override the default. This is very common for NLP models, which often consume sequences of different lengths.
Let's suppose that we build a word generator, where each character is encoded as some integer value (presumably an index in some character table). If all of the words were exactly the same length, a tensor of shape batch x length would suffice: a Dataset would return a single word as a 1D tensor, and the default collate logic would properly stack them into a 2D tensor. However, if the generated word lengths varied, this approach would no longer work (trivia 3).
To keep this simple, let's assume a Dataset which creates the encoded word tensors but does not care about the length:
import string

import torch
from torch.utils.data import Dataset


class WordDataset(Dataset):
    def __init__(self):
        self.words = ["a", "full", "dictionary", "of", "words"]  # + many more
        self.chars = string.printable  # a list of all interesting characters

    def __len__(self):
        return len(self.words)

    def __getitem__(self, item):
        # encode each character as its index in the character table
        word = self.words[item]
        encoded = [self.chars.index(c) for c in word]
        return torch.tensor(encoded)
To make it work with the DataLoader, we must change the collate mechanism:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def collate_fn(examples):
    # pad every word, with zeros on the right, to the longest word in this mini-batch
    max_length = max(len(ex) for ex in examples)
    batch = [
        F.pad(ex, (0, max_length - len(ex)))
        for ex in examples
    ]
    # stack the now equal-length 1D tensors into a single batch x max_length tensor
    return torch.stack(batch)

loader = DataLoader(WordDataset(), batch_size=2, collate_fn=collate_fn)
The examples argument of the collate_fn is a list of length 2, because that is the requested mini-batch size. Note that the DataLoader has a drop_last parameter which controls what happens when the overall number of examples is not divisible by the batch size. Its default value is False, so it is best to expect that collate_fn can receive fewer examples on some calls.
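Iterating over this loader on the five-word dataset above illustrates both the per-batch padding and the smaller final batch:

for batch in loader:
    print(batch.shape)
# torch.Size([2, 4])  - "a" and "full", padded to length 4
# torch.Size([2, 10]) - "dictionary" and "of", padded to length 10
# torch.Size([1, 5])  - "words" alone: the last batch is smaller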
Parallelization
It is important to draw a distinction between the parallelization that happens in the data loading pipeline and the one that happens during the forward and backward passes. I am consciously ignoring the distributed training setup and focusing on single-machine training for simplicity.
To accelerate neural network operations, we usually rely on SIMD units such as GPUs. SIMD stands for single instruction, multiple data; this kind of architecture is therefore good at parallelizing a large number of very simple operations. All kinds of multidimensional array operations fit this scheme perfectly.
When it comes to data input pipelines, GPUs are less practical (trivia 4), as some of the operations are not suitable for this kind of parallelization. That's why the DataLoader does not parallelize single operations; it rather tries to load multiple examples simultaneously using different threads/processes (trivia 5). In practice, this means that every invocation of __getitem__ (trivia 6) runs on a single thread, but multiple such threads are running at the same time. This is an important simplification, because it means that as a programmer you don't have to care about data races in your code.
There are two practical implications of this design:
By default, all code within the actual training loop runs on the same thread. The parallelization happens only within torch operations, and only one training iteration is performed at a given time.
You can't move tensors to the GPU in __getitem__. To avoid resource allocation issues, PyTorch prevents you from doing this outside of the main process (trivia 7); see the sketch below.
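A minimal sketch of the intended division of labor; dataset and model are placeholders for any Dataset and module from above:

loader = DataLoader(dataset, batch_size=32, num_workers=4)  # 4 worker processes (trivia 7)

for batch in loader:
    batch = batch.to("cuda")     # the GPU transfer belongs here, in the main process
    prediction = model(batch)    # hypothetical model; the forward pass runs on the GPU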
The details are not important for map-style datasets, but they become important for iterable-style datasets. I am more specific in my other article.
Summary
Top notes:
read an actual tutorial or the documentation examples for the Dataset and DataLoader if you haven't already
return only one example from the __getitem__ function, rather than the whole batch
try to return named tuples from __getitem__
use the default collate logic whenever possible
feel free to ignore any of the points above
Trivia
1. For a more practical implementation of directory-listing datasets, see the ImageFolder and DatasetFolder classes.
2. Oddly enough, although strings are technically iterables, the default collate_fn does not recurse into them. They are not crudely concatenated either; they are just left as lists of strings, which might be useful for some kinds of data or for debugging purposes. I like to log the name of a file that crashed my training pipeline :P.
3. Theoretically, it would be sufficient to just pad each word with zeros up to some perceived maximal length and use the default collate logic. Note, however, that this would be very inefficient in practice: just one absurdly long word in the entire dataset could quadruple the average memory usage. In our example, the max_length is computed within the mini-batch, so the problem is not noticeable.
4. NVidia develops DALI (Data Loading LIbrary), which enables us to perform data loading steps on the GPU. It might be worth a look after confirming that data loading is the bottleneck of the training pipeline. It lacks the flexibility of the PyTorch datasets, so I wouldn't recommend it as the first thing to try.
5. Due to well-known Python GIL issues, the DataLoader uses multi-processing rather than multi-threading, although I'd consider them just two implementations of the same concept.
6. Somewhat surprisingly, the DataLoader uses multi-processing not only for the example-loading functions (i.e. __getitem__) but also for the batch-forming collate_fn.
7. The DataLoader accepts a num_workers parameter. The default value is 0, which indicates that all computations happen in the main process. This eliminates GPU allocation issues, but also seriously impedes performance.