PyTorch Datasets: A Guide to Loading and Using Popular Datasets

A Comprehensive Guide To Datasets With PyTorch

Undeniably, the most crucial component of any project (small, large, or real-world) is the data we work on. Especially in machine learning and data science projects, gathering and cleaning datasets is one of the most significant steps of the project pipeline.

Data collection and preprocessing must be performed correctly so that the model can learn the intricate features of the dataset.

When it comes to collecting datasets, there are many public sources, such as Kaggle, Hugging Face, and various government data repositories, from which we can download datasets and use them in our projects.

Google BigQuery also hosts a large collection of public datasets that we can query and download.

One thing common to all these repositories is that we have to download the data and preprocess it manually. However, a few frameworks let us load a dataset directly into our environment without handling the download and preprocessing ourselves.

PyTorch provides a wide range of datasets for machine learning tasks, including computer vision and natural language processing. The torchvision module offers popular datasets like CelebA, CIFAR, COCO, MNIST, and ImageNet. With the help of the DataLoader and Dataset classes, you can efficiently load and utilize these datasets in your projects. This guide walks you through the process of importing and loading datasets, using the MNIST dataset as an example.

In this post, we are going to talk about PyTorch datasets.

Introduction to PyTorch and Its Dataset Categories

PyTorch is an open-source machine learning framework that supports many tasks, such as computer vision and natural language processing. It provides the building blocks for deep learning models, including layers, loss functions, and optimizers.

It also provides many popular benchmark datasets, along with data loaders to load them into our environment.

PyTorch Datasets

The datasets in PyTorch are primarily organized by domain library, as follows.

  • torchaudio (audio)
  • torchvision (images and video)
  • torchtext (text)

We are going to look at the datasets available in the torchvision module.

TorchVision: A Module for Computer Vision Tasks

Torchvision is a PyTorch library dedicated to computer vision tasks such as image classification, object detection, and segmentation.

We can find the following datasets in the image category.

CelebA Dataset

One of the most popular datasets on the list is CelebA (CelebFaces Attributes). It contains more than 200,000 facial images of celebrities from around the world, and each image is annotated with 40 binary attributes. It is used for face attribute recognition, face detection, and landmark localization. It has also become a standard training dataset for GANs that generate synthetic human faces.

Its constructor has the following signature.

torchvision.datasets.CelebA(root: str, split: str = 'train', target_type: Union[List[str], str] = 'attr', transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)


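CIFAR Dataset

The CIFAR datasets are named after the Canadian Institute For Advanced Research and are widely used for training computer vision algorithms. There are two variants, CIFAR-10 and CIFAR-100, and each contains 60,000 32x32 color images. Neither is a subset of the other; both are separately labeled subsets of the 80 Million Tiny Images collection. CIFAR-10 has 10 classes (airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks), while CIFAR-100 has 100 classes. Thanks to its small size, CIFAR-10 is a go-to dataset for researchers building and benchmarking computer vision models.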

The CIFAR10 class has the following signature.

torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)

COCO Dataset

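Common Objects in COntext (COCO) is a benchmark dataset for many computer vision applications and research. It was developed by Microsoft and is used for image recognition, object detection, segmentation, and image captioning. torchvision exposes it through classes such as CocoCaptions and CocoDetection. Both require the pycocotools package and a locally downloaded copy of the images and annotation file.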

torchvision.datasets.CocoCaptions(root: str, annFile: str, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, transforms: Union[Callable, NoneType] = None)


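MNIST Dataset

This is another dataset that every machine learning student knows. Its name stands for Modified National Institute of Standards and Technology, because it was built by modifying and re-mixing handwritten samples from NIST's original databases. It contains 70,000 grayscale 28x28 images of handwritten digits (0-9), split into 60,000 training and 10,000 test images, and is used for classification.

torchvision.datasets.MNIST(root: str, train: bool = True, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)

Fashion MNIST Dataset

While MNIST contains handwritten digits, Fashion-MNIST contains images of clothing and accessories from Zalando's catalog, in 10 classes. It keeps MNIST's format (28x28 grayscale images, 60,000 for training and 10,000 for testing), so it works as a drop-in replacement. It is also used for classification.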

torchvision.datasets.FashionMNIST(root: str, train: bool = True, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)


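EMNIST Dataset

EMNIST (Extended MNIST) extends MNIST with handwritten letters as well as digits, in the same 28x28 format. Its required split argument selects one of several variants: byclass, bymerge, balanced, letters, digits, or mnist.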

torchvision.datasets.EMNIST(root: str, split: str, **kwargs)


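ImageNet Dataset

ImageNet is another benchmark dataset for image-related tasks such as object recognition. The full database contains more than 14 million images, but in practice we work with a subset. torchvision.datasets.ImageNet loads the ILSVRC 2012 classification subset, which has about 1.28 million training images in 1,000 classes. Smaller derivatives such as Tiny ImageNet are also popular. Unlike the datasets above, ImageNet cannot be downloaded automatically by torchvision. You must obtain the archives from the ImageNet website and place them in the root directory.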

torchvision.datasets.ImageNet(root: str, split: str = 'train', download: Union[str, NoneType] = None, **kwargs)

That concludes the introduction to the important datasets in the torchvision module. Now, let us look at two crucial classes used to load datasets into our environment.

Efficient Dataset Loading with the DataLoader Class

The DataLoader class, as the name suggests, is used to load datasets efficiently. It groups samples into batches, can shuffle them, and can load them in parallel worker processes. It is a PyTorch utility class, found in the torch.utils.data module.

It is used together with the Dataset class. A DataLoader wraps a Dataset object and repeatedly asks it for individual samples.
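The effect of the main DataLoader arguments can be seen on a tiny in-memory dataset, without downloading anything. This sketch uses TensorDataset with made-up values. It shows how batch_size splits ten samples and how drop_last discards the incomplete final batch. shuffle=True would reorder samples every epoch, and num_workers (not used here) controls how many background processes load data.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# A toy dataset of 10 samples, each a 3-feature vector with an integer label
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# batch_size groups samples; drop_last discards a final, incomplete batch
loader = DataLoader(dataset, batch_size=4, shuffle=False)
strict_loader = DataLoader(dataset, batch_size=4, shuffle=False, drop_last=True)

print([len(y) for _, y in loader])         # [4, 4, 2]
print([len(y) for _, y in strict_loader])  # [4, 4]
print(len(loader), len(strict_loader))     # 3 2
```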

Accessing Datasets using the Dataset Class

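The Dataset class (torch.utils.data.Dataset, in the same module as DataLoader) is an abstract class that represents a dataset. A map-style subclass implements __getitem__ to return the sample at a given index and __len__ to report how many samples there are. Every torchvision dataset, such as ImageNet or MNIST, is a subclass of Dataset, which is why any of them can be passed directly to a DataLoader. Using these two classes, we can import a dataset and load it into the environment as follows.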

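import torch
import torchvision
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')  # archives must already be in this folder
data_loader = torch.utils.data.DataLoader(imagenet_data, batch_size=4, shuffle=True)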

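Here, we create the ImageNet dataset from the torchvision.datasets module and wrap it in a DataLoader that yields batches of 4 samples, reshuffled every epoch. ImageNet images vary in size, so in practice you would also pass a transform that resizes the images and converts them to tensors. Otherwise the samples cannot be stacked into batches.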
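The Dataset contract described above is small, so writing your own dataset mostly means implementing __len__ and __getitem__. Here is a minimal sketch using a hypothetical SquaresDataset, where sample i is the pair (i, i squared). A DataLoader can batch it exactly like a torchvision dataset.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: sample i is (i, i squared)."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        # DataLoader uses this to know how many indices exist
        return self.n

    def __getitem__(self, idx: int):
        # Return one (input, target) pair as tensors
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx) ** 2])
        return x, y


ds = SquaresDataset(5)
print(len(ds))    # 5
loader = DataLoader(ds, batch_size=2)
xb, yb = next(iter(loader))
print(xb.shape)   # torch.Size([2, 1])
```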

Step-by-Step Guide: Importing and Loading the MNIST Dataset

import torch
import torchvision
from torchvision import datasets, transforms

First, we import the torch and torchvision modules. Then, we import the datasets and transforms modules from torchvision.

Next, we need to convert the images to tensors (the primary data type of the PyTorch library) with ToTensor, and then normalize the tensors with Normalize. The values 0.1307 and 0.3081 are the mean and standard deviation of the pixel values in the MNIST training set.

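transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL image -> float tensor in [0, 1], shape (1, 28, 28)
    transforms.Normalize((0.1307,), (0.3081,)),  # (x - mean) / std
])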
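Normalization constants like these are ordinary per-channel statistics computed over the training images. The sketch below shows one way to compute them with a hypothetical compute_mean_std helper. It uses a loader whose images are only converted to tensors, not yet normalized. A synthetic dataset stands in for MNIST so the example runs offline. Run over MNIST loaded with only transforms.ToTensor(), this kind of computation should give values close to 0.1307 and 0.3081.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def compute_mean_std(loader: DataLoader):
    """Per-channel mean and std over every pixel the loader yields (images must not be normalized yet)."""
    total, total_sq, count = 0.0, 0.0, 0
    for images, _ in loader:                        # images: (batch, channels, height, width)
        pixels = images.transpose(0, 1).flatten(1)  # -> (channels, batch * height * width)
        total = total + pixels.sum(dim=1)
        total_sq = total_sq + (pixels ** 2).sum(dim=1)
        count += pixels.shape[1]
    mean = total / count
    std = (total_sq / count - mean ** 2).sqrt()     # population std via E[x^2] - E[x]^2
    return mean, std


# Synthetic stand-in for an un-normalized image dataset: 8 single-channel 4x4 images
images = torch.cat([torch.zeros(4, 1, 4, 4), torch.ones(4, 1, 4, 4)])
loader = DataLoader(TensorDataset(images, torch.zeros(8)), batch_size=3)
mean, std = compute_mean_std(loader)
print(mean.tolist(), std.tolist())  # [0.5] [0.5]
```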

Moving on, we define the train and test sets of the MNIST dataset.

trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

Figure: Downloading the MNIST dataset

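The first line defines the training set (train=True, 60,000 images). The second line defines the test set, so its train parameter is set to False (10,000 images). With download=True, the files are downloaded into ./data only if they are not already there.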
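MNIST ships with only a training and a test split. If you also want a validation set, a common approach is to carve it out of the training set with torch.utils.data.random_split. This sketch uses a 100-sample synthetic stand-in so it runs offline. With the real trainset, you might split it as [50000, 10000] instead; those sizes are an illustrative choice. The seeded generator makes the split reproducible.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the 60,000-image MNIST training set: 100 fake samples
full_train = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))

# Hold out 20% for validation; the seeded generator makes the split reproducible
generator = torch.Generator().manual_seed(42)
train_subset, val_subset = random_split(full_train, [80, 20], generator=generator)
print(len(train_subset), len(val_subset))  # 80 20
```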

Now, we need to define the data loaders for the MNIST train and test sets.

batch_size = 64
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

We set the batch size to 64. The trainloader loads the trainset in batches of that size, and shuffle=True reshuffles the samples at every epoch. The testloader does the same for the test set but keeps shuffle=False. Evaluation does not benefit from shuffling, and a fixed order keeps results reproducible.

Let’s visualize!

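import matplotlib.pyplot as plt
import numpy as np
dataiter = iter(trainloader)
images, labels = next(dataiter)  # one batch: images (64, 1, 28, 28), labels (64,)
for i in range(4):
    plt.subplot(2, 2, i + 1)
    plt.imshow(images[i].squeeze().numpy(), cmap='gray')  # drop the channel dim -> (28, 28)
    plt.title(f"Label: {labels[i].item()}")
    plt.axis('off')
plt.tight_layout()
plt.show()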

We import the matplotlib and numpy libraries in the first two lines. Then, we call iter() on the trainloader to get an iterator over its batches and store it in a new variable called dataiter. Calling next() on this iterator returns the first batch, which holds 64 images together with their 64 labels. Lastly, the for loop displays the first four images of the batch in a 2x2 grid, each titled with its label.

Figure: MNIST samples

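In the same way, we can load any other torchvision dataset through datasets.<DatasetName> (for example, datasets.CIFAR10). Only the constructor arguments, such as train or split, and the normalization values need to change for each dataset.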
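Because MNIST, FashionMNIST, CIFAR10 and CIFAR100 share the same root/train/download/transform constructor arguments, the MNIST steps can be wrapped in one function. The sketch below shows this with make_loaders, a hypothetical helper. Datasets that take a split argument instead of train, such as CelebA, EMNIST and ImageNet, need their own constructor calls. A stand-in class replaces the real dataset in the check so the example runs offline.

```python
from typing import Callable, Optional, Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loaders(dataset_cls, transform: Optional[Callable] = None, root: str = "./data",
                 batch_size: int = 64, download: bool = True) -> Tuple[DataLoader, DataLoader]:
    """Hypothetical helper for torchvision datasets that take a `train` flag,
    such as datasets.MNIST, datasets.FashionMNIST, datasets.CIFAR10 and datasets.CIFAR100."""
    trainset = dataset_cls(root=root, train=True, download=download, transform=transform)
    testset = dataset_cls(root=root, train=False, download=download, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)  # reshuffled every epoch
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)   # fixed evaluation order
    return trainloader, testloader


# Offline check with a stand-in class that accepts the same constructor arguments
class StubDataset(TensorDataset):
    def __init__(self, root, train, download, transform):
        n = 100 if train else 20                     # pretend split sizes
        super().__init__(torch.zeros(n, 1, 28, 28), torch.zeros(n, dtype=torch.long))


trainloader, testloader = make_loaders(StubDataset, batch_size=32)
print(len(trainloader), len(testloader))  # 4 1
```

With a real dataset, a call such as make_loaders(datasets.CIFAR10, transform=transforms.ToTensor()) would download CIFAR-10 and return both loaders.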


Conclusion

We have discussed why data collection matters, especially in machine learning and data science projects, where a model's performance depends heavily on the quality of its data. Although many public repositories let us download datasets, it is often more convenient to load them directly through Python frameworks such as PyTorch, Hugging Face, and Keras.

We have covered a few popular and important datasets from PyTorch's vision library, torchvision. Each dataset is exposed as a class whose constructor arguments are listed in the documentation, and it can be used for research or a project.

We have also gone through the two key PyTorch classes for importing and loading data, Dataset and DataLoader. Finally, we walked through an example of importing and loading the MNIST dataset with them.

How can you use PyTorch datasets to create innovative and impactful machine-learning projects?

