Solving Core Data Loading Challenges with PyTorch

PyTorch's data utilities are built around two concepts: the dataset and the data loader. Both serve the general purpose of loading training data, so they are easily confused. Why do we need both?

Building a deep learning input pipeline involves three challenges:

  • loading the raw data and transforming it into suitable programming structures

  • parallelization of the loading routines (to keep up with the training speed)

  • construction of mini-batch tensors

The first challenge depends on the raw data format and on the loss function. Each problem requires a different solution. In PyTorch, it is handled by the Dataset class, which is meant to be specialized by the researcher. The other two challenges are less varied. Additionally, they are either difficult or simply tedious to implement. For them, PyTorch provides a general solution in the form of the DataLoader class.

How does each of these concepts help solve its corresponding challenges? Please note that the purpose of this article is exclusively to answer this question. This is not a comprehensive guide to PyTorch datasets and data loaders.

Dataset

For the sake of simplicity, we are going to focus on the most common map-style Dataset class and ignore the IterableDataset class, which follows a slightly different protocol.

The user must provide their own specialization of this class and override two methods: __len__ and __getitem__. In principle, Dataset is supposed to mimic a read-only map/dictionary. The documentation states that:

[Dataset] represents a map from (possibly non-integral) indices/keys to data samples

This is a little misleading, because in many cases non-integral keys would not work as intended. Data utilities often assume that a Dataset specialization accepts integer keys between 0 and its length.

The following simple Dataset specialization reads all images from a given directory (trivia 1):

from pathlib import Path

from torch.utils.data import Dataset

class ImageDirectoryDataset(Dataset):
    def __init__(self, path: Path):
        # store only the file paths - the images themselves stay on disk
        self.image_paths = list(path.iterdir())

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, item):
        image_path = self.image_paths[item]
        # load_as_tensor stands for the actual image decoding routine
        return load_as_tensor(image_path)

The usage of such a class resembles the usage of a Python list. Unlike a list, though, this class has an acceptable memory footprint: the Dataset stores only the paths to the actual image files, not the image tensors. Given 1 million images, the required memory drops from possibly hundreds of gigabytes to just tens of megabytes. The programming interface stays the same - dataset[5] returns a ready-to-use tensor from the dataset.
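A minimal usage sketch, assuming a hypothetical local directory called "images":

dataset = ImageDirectoryDataset(Path("images"))
print(len(dataset))  # the number of image files found in the directory
tensor = dataset[5]  # the image is loaded from disk only now, on demand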

An idiomatic Dataset specialization returns a single example for a given index; the mini-batch creation happens in the DataLoader. Note that __getitem__ can return an arbitrary Python object, but such objects must satisfy certain properties to be usable with the DataLoader, as discussed below (DataLoader - Collate section). To keep things simple, the safest choice is to always return a tuple of tensors or primitive types. Importantly, this includes named tuples, which make the code much more readable.

Transforms

A Dataset doesn't need to be passed directly to a DataLoader. It can also be fed to functions that use it to construct a new Dataset. One example is the ConcatDataset class, which accepts a list of Datasets to merge. Another is the random_split function, which allows for quick separation of a validation set out of the training dataset.
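A short sketch of both utilities, using plain lists as toy datasets (which works, as the next section explains):

import torch
from torch.utils.data import ConcatDataset, random_split

dataset_a = [torch.tensor(i) for i in range(6)]
dataset_b = [torch.tensor(i) for i in range(4)]

merged = ConcatDataset([dataset_a, dataset_b])
print(len(merged))  # 10 - behaves like one dataset

# carve out a validation set: 8 training examples, 2 validation examples
train_set, valid_set = random_split(merged, [8, 2])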

List as a PyTorch Dataset

The Dataset class has a very simple protocol that requires only the __len__ and __getitem__ methods. Additionally, it overrides the + operator for concatenation. All of those properties are already held by the standard Python list. This is a useful observation when your entire dataset fits in memory - especially when you write a simple test and need only a handful of examples.

Theoretically, there is a risk that some utility would explicitly check isinstance(x, Dataset) and therefore fail for a list. I find it rather unlikely, but it is an interesting example of why such checks are considered a bad practice. PyTorch utilities do not perform this particular check.
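A quick sketch: a plain list of same-shape tensors passed straight to a DataLoader:

import torch
from torch.utils.data import DataLoader

# a handful of fixed-size examples - enough for a simple test
examples = [torch.rand(3) for _ in range(5)]
loader = DataLoader(examples, batch_size=2)

for batch in loader:
    print(batch.shape)  # torch.Size([2, 3]), torch.Size([2, 3]), torch.Size([1, 3])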

DataLoader

DataLoader is the Swiss Army knife of data handling in PyTorch. It is a wrapper for datasets that enables efficient iteration over batches of examples. I cover only the two most important features, namely:

  • the mini-batch collating

  • parallelization

Collate

The most highlighted feature of the DataLoader is the mini-batch construction, i.e. collating. This process is managed by the collate_fn. Its default behavior is intuitive and you should rely on it whenever possible. It is important to understand the collate logic in order to implement compatible Datasets.

Default Logic

In the default collate logic, each tensor is stacked with its counterparts along a new batch dimension, and primitives are collected into 1D tensors. For each iterable structure (like a tuple), the logic is applied recursively. That means you can use an arbitrary structure of nested iterables, provided there is a tensor or a primitive at the bottom.

Let's consider the following object being returned from the __getitem__ method:

from typing import NamedTuple, Tuple

import torch

class MyExample(NamedTuple):  # `NamedTuple` is important here!
    image: torch.Tensor  # shape = channels x height x width
    label: int
    position: Tuple[int, int]

The collate_fn receives a list of such objects and analyzes their structure recursively:

  1. The overall structure is the tuple so we must iterate over all 3 entries.

  2. The first entry is a tensor image - easy, just call torch.stack on all elements from the list

  3. The second entry is an integer label - equally easy, the integers are packed into a single 1D tensor

  4. The third entry position is a tuple - must recurse

  5. Each entry of the position is an integer - each is collated separately

The final structure is not flattened and the original type is preserved. The constructed batch would be of the class MyExample and have the same field names. The only difference is that all integers have now become tensors, and the dimensions of each tensor include the batch dimension. The image would have a 4D shape instead of a 3D one. All the integers (previously scalars) would become 1D tensors. This includes the nested tuple, which wouldn't be magically joined (trivia 2).
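To make this concrete, here is a small sketch that collates two MyExample instances by hand; default_collate is exposed as torch.utils.data.default_collate in recent PyTorch versions:

import torch
from torch.utils.data import default_collate

examples = [
    MyExample(image=torch.rand(3, 32, 32), label=0, position=(1, 2)),
    MyExample(image=torch.rand(3, 32, 32), label=1, position=(3, 4)),
]

batch = default_collate(examples)
print(type(batch).__name__)  # MyExample - the named tuple type is preserved
print(batch.image.shape)     # torch.Size([2, 3, 32, 32]) - a new batch dimension
print(batch.label)           # tensor([0, 1])
print(batch.position)        # [tensor([1, 3]), tensor([2, 4])] - collated per entry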

Note that, in order for this to work, each example returned by the Dataset must have the same structure, e.g. position cannot be an (int, int) tuple for some examples and (int, int, int) for others. Additionally, since there is a crude torch.stack underneath, all tensors must have the same shape. It is a good idea to defensively resize all images to the same size in __getitem__.

Custom Logic

Sometimes there is no way to design the Dataset so that the default collate_fn yields appropriate batches. In those cases, you should implement your own collate_fn and pass it to the DataLoader's constructor to override the default. This is very common for NLP models, which often consume sequences of different lengths.

Let's suppose that we build a word generator. Each character is encoded as some integer value (presumably an index in some character table). If all of the words were exactly the same length, a tensor of shape batch x length would suffice. A Dataset would return a single word as a 1D tensor, and the default collate logic would properly stack them into a 2D tensor. However, if the generated word lengths varied, this approach would no longer work (trivia 3).

To keep this simple let's assume a Dataset which creates the encoded word tensors, but does not care about the length:

import string

import torch
from torch.utils.data import Dataset

class WordDataset(Dataset):
    def __init__(self):
        self.words = ["a", "full", "dictionary", "of", "words"]  # + many more
        self.chars = string.printable  # all characters we care about

    def __len__(self):
        return len(self.words)

    def __getitem__(self, item):
        word = self.words[item]
        # encode each character as its index in the character table
        encoded = [self.chars.index(c) for c in word]
        return torch.tensor(encoded)

To make it work with the DataLoader we must change the collate mechanism:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

def collate_fn(examples):
    # pad every word with zeros up to the longest word in this mini-batch
    max_length = max(len(ex) for ex in examples)
    batch = [
        F.pad(ex, (0, max_length - len(ex)))
        for ex in examples
    ]
    return torch.stack(batch)  # shape: batch x max_length

dataset = WordDataset()
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

The examples argument of the collate_fn is a list of length 2, because that is the requested mini-batch size. Note that there is a DataLoader parameter drop_last which controls what happens when the overall number of examples is not divisible by the batch size. Its default value is False, so it is better to expect that collate_fn can receive fewer examples on some calls.
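Continuing with the loader above (shuffle defaults to False, so the word order is preserved):

# batch_size=2 with 5 words and the default drop_last=False yields 2 + 2 + 1 examples
for batch in loader:
    print(batch.shape)  # torch.Size([2, 4]), torch.Size([2, 10]), torch.Size([1, 5])

# drop_last=True would discard the final incomplete mini-batch instead
strict_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn, drop_last=True)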

Parallelization

It is important to draw a distinction between the parallelization that happens in the data loading pipeline and the one that happens during the forward and backward passes. I am consciously ignoring the distributed training setup and focusing on single-machine training for simplicity.

To accelerate neural network operations, we usually rely on SIMD units such as GPUs. SIMD stands for single instruction, multiple data; this kind of architecture is therefore good at parallelizing a large number of very simple operations. All kinds of multidimensional array operations fit this scheme perfectly.

When it comes to data input pipelines, GPUs are less practical (trivia 4). Some of the operations are not suitable for this kind of parallelization. That's why the DataLoader does not parallelize single operations, but rather tries to load multiple examples simultaneously using different threads/processes (trivia 5). In practice, this means that every invocation of __getitem__ (trivia 6) runs on a single thread, but multiple such threads run at the same time. This is an important simplification, because as a programmer you don't have to worry about data races in your code.

There are two practical implications of this design:

  • By default, all code within the actual training loop runs on the same thread; the parallelization happens only inside torch operations. Only one training iteration is performed at any given time.

  • You can't move tensors to the GPU inside __getitem__. To avoid resource allocation issues, PyTorch prevents you from doing this outside of the main process (trivia 7). A sketch of the resulting setup follows below.
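A hedged sketch of a typical setup, assuming dataset is a map-style Dataset whose examples collate into plain tensors; num_workers and pin_memory are standard DataLoader parameters:

import torch
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# four worker processes load and collate examples in parallel with training
loader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)

for batch in loader:
    # the GPU transfer happens here, in the main process - not in __getitem__
    batch = batch.to(device)
    ...  # forward pass, loss, backward pass, optimizer step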

The details are not important for map-style datasets but become important for iterable-style datasets. I'm more specific in my other article.

Summary

Top notes:

  • read an actual tutorial or documentation examples for the Dataset and DataLoader if you haven't already

  • return only one example from the __getitem__ function rather than the whole batch

  • try to return named tuples from the __getitem__

  • use the default collate logic whenever possible

  • feel free to ignore any of the points above

Trivia

  1. For a more practical implementation of a directory-listing dataset, see torchvision's ImageFolder and DatasetFolder classes

  2. Oddly enough, although strings are technically iterables, the default collate_fn does not recurse into them. They are not crudely concatenated either; they are simply left as a list of strings, which can be useful for some kinds of data or for debugging - I like to log the name of the file that crashed my training pipeline :P (see the snippet after this list)

  3. Theoretically, it would be sufficient to pad each word with zeros up to some perceived maximal length and use the default collate logic. Note, however, that this would be very inefficient in practice: just one absurdly long word in the entire dataset could quadruple the average memory usage. In our example, max_length is computed within the mini-batch, so the problem does not arise

  4. NVIDIA develops DALI (the Data Loading Library), which makes it possible to perform data loading steps on the GPU. It might be worth a look after confirming that data loading is the bottleneck of the training pipeline. It lacks the flexibility of PyTorch datasets, so I wouldn't recommend it as the first thing to try

  5. Due to the well-known Python GIL issues, the DataLoader uses multi-processing rather than multi-threading, although I'd consider them just two implementations of the same concept

  6. Somewhat surprisingly, the DataLoader uses multi-processing not only for the example-loading functions (i.e. __getitem__) but also for the batch-forming collate_fn

  7. DataLoader accepts a num_workers parameter. The default value is 0, which means that all computations happen in the main process. This eliminates GPU allocation issues, but can seriously impede performance
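As a follow-up to trivia 2, a quick sketch of the string pass-through behavior, again using default_collate directly:

from torch.utils.data import default_collate

batch = default_collate(["cat.jpg", "dog.jpg"])
print(batch)  # ['cat.jpg', 'dog.jpg'] - strings survive collation as a plain list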