An essential part of any Deep Learning based system is to structure the data loading pipeline such that it can be seamlessly integrated with your deep learning model. In this tutorial, we will understand the working of data loading functionalities provided by PyTorch and learn to use them in our own deep learning projects effectively.
To learn how to use PyTorch Datasets and DataLoaders, just keep reading.
Looking for the source code to this post?
Jump Right To The Downloads Section
We will discuss the following in detail:
- How to restructure a dataset into training and validation sets
- How to load a dataset in PyTorch and utilize in-built PyTorch data augmentations
- How to set up PyTorch DataLoaders to efficiently access data samples
Our example Flowers dataset
Our goal is to create a basic data loading pipeline with the help of the PyTorch `Dataset` and `DataLoader` classes, which enable us to easily and efficiently access our data samples and pass them to our deep learning model.
For this tutorial, we will be using a dataset of flowers (see Figure 1) that consists of 5 types of flowers:
- Tulips
- Daisy
- Dandelion
- Roses
- Sunflowers
Configuring your development environment
To follow this guide, you need to have the PyTorch deep learning library, matplotlib, OpenCV and imutils packages installed on your system.
Luckily, these packages are extremely easy to install using pip:
```
$ pip install torch torchvision
$ pip install matplotlib
$ pip install opencv-contrib-python
$ pip install imutils
```
If you need help configuring your development environment for PyTorch, I highly recommend that you read the PyTorch documentation — PyTorch’s documentation is comprehensive and will have you up and running quickly.
Having problems configuring your development environment?
All that said, are you:
- Short on time?
- Learning on your employer’s administratively locked system?
- Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
- Ready to run the code right now on your Windows, macOS, or Linux system?
Then join PyImageSearch University today!
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Project structure
Before we proceed, we need to first download the flowers dataset that we will be using in this tutorial.
Start with the “Downloads” section of this tutorial to access the source code and example flowers dataset.
From there, unzip the archive, and you should find the following project directory:
```
├── build_dataset.py
├── builtin_dataset.py
├── flower_photos
│   ├── daisy
│   ├── dandelion
│   ├── roses
│   ├── sunflowers
│   └── tulips
├── load_and_visualize.py
├── pyimagesearch
│   ├── config.py
│   ├── __init__.py
```
The `build_dataset.py` script is responsible for dividing and structuring the dataset into training and validation sets. Furthermore, the `builtin_dataset.py` script shows how to directly download and load some commonly used computer vision datasets, such as MNIST, using the PyTorch API.

The `flower_photos` folder contains the dataset that we will be using. It consists of 5 subdirectories (i.e., daisy, dandelion, roses, sunflowers, and tulips), each containing images of the corresponding flower category.

On the other hand, the `load_and_visualize.py` script is responsible for loading and accessing the data samples with the help of the PyTorch `Dataset` and `DataLoader` classes.

The `config.py` file in the `pyimagesearch` folder stores information such as parameters, initial settings, and configurations for our code.
Note that after downloading the dataset, each image will have a path of the form `folder_name/class_name/image_id.jpg`. For example, shown below are the paths for some images in the flower dataset.

```
flower_photos/dandelion/8981828144_4b66b4edb6_n.jpg
flower_photos/sunflowers/14244410747_22691ece4a_n.jpg
flower_photos/roses/1666341535_99c6f7509f_n.jpg
flower_photos/sunflowers/19519101829_46af0b4547_m.jpg
flower_photos/dandelion/2479491210_98e41c4e7d_m.jpg
flower_photos/sunflowers/3950020811_dab89bebc0_n.jpg
```
Creating our configuration file
First, we discuss the `config.py` file, which stores the configurations and parameter settings used in this tutorial.
```python
# specify path to the flowers and mnist dataset
FLOWERS_DATASET_PATH = "flower_photos"
MNIST_DATASET_PATH = "mnist"

# specify the paths to our training and validation set
TRAIN = "train"
VAL = "val"

# set the input height and width
INPUT_HEIGHT = 128
INPUT_WIDTH = 128

# set the batch size and validation data split
BATCH_SIZE = 8
VAL_SPLIT = 0.1
```
We define the path to our flowers dataset folder on Line 2 and the path for the MNIST dataset on Line 3. In addition to this, we specify the path names for our training and validation set folders on Lines 6 and 7.
On Lines 10 and 11, we define the required height and width for input images which will later enable us to resize our input to a dimension which our model can accept.
Furthermore, we define the batch size and the fraction of the dataset to be held out as validation set on Lines 14 and 15.
Splitting the Dataset into Training and Validation Sets
Here, we discuss how to restructure our flowers dataset into a training and validation set.
Open the `build_dataset.py` file in your project directory structure, and let's get started.
```python
# USAGE
# python build_dataset.py

# import necessary packages
from pyimagesearch import config
from imutils import paths
import numpy as np
import shutil
import os
```
We start by importing the required packages on Lines 5-9.
```python
def copy_images(imagePaths, folder):
    # check if the destination folder exists and if not create it
    if not os.path.exists(folder):
        os.makedirs(folder)

    # loop over the image paths
    for path in imagePaths:
        # grab image name and its label from the path and create
        # a placeholder corresponding to the separate label folder
        imageName = path.split(os.path.sep)[-1]
        label = path.split(os.path.sep)[-2]
        labelFolder = os.path.join(folder, label)

        # check to see if the label folder exists and if not create it
        if not os.path.exists(labelFolder):
            os.makedirs(labelFolder)

        # construct the destination image path and copy the current
        # image to it
        destination = os.path.join(labelFolder, imageName)
        shutil.copy(path, destination)
```
Starting on Line 11, we define the `copy_images` function. It accepts a list of image paths (i.e., `imagePaths`) and a destination `folder`, and copies the input images to the destination. This function will come in handy when we want a set of image paths to be copied to the training or validation folder. Next, let's walk through this function line by line.
We first check if the destination folder already exists, and if not we create it on Lines 13 and 14.
On Line 17, we loop over each path in our input `imagePaths` list. For each `path` (which is of the form `root/class_label/image_id.jpg`) in the list:
- We separately store the `image_id` and `class_label` on Lines 20 and 21, respectively.
- On Line 22, we define a folder `labelFolder` in the input destination to store all images from a particular `class_label`. We create the `labelFolder` on Lines 25 and 26 if it does not exist already.
- We then construct the destination path (inside `labelFolder`) for an image of a given `image_id` (Line 30) and copy the current image to it (Line 31).
Once we have defined the `copy_images` function, we are ready to understand the main code required to split our dataset into training and validation sets.
```python
# load all the image paths and randomly shuffle them
print("[INFO] loading image paths...")
imagePaths = list(paths.list_images(config.FLOWERS_DATASET_PATH))
np.random.shuffle(imagePaths)

# generate training and validation paths
valPathsLen = int(len(imagePaths) * config.VAL_SPLIT)
trainPathsLen = len(imagePaths) - valPathsLen
trainPaths = imagePaths[:trainPathsLen]
valPaths = imagePaths[trainPathsLen:]

# copy the training and validation images to their respective
# directories
print("[INFO] copying training and validation images...")
copy_images(trainPaths, config.TRAIN)
copy_images(valPaths, config.VAL)
```
Line 35 loads the paths of all images in the flower dataset into a list named `imagePaths`. We randomly shuffle the image paths with the help of `numpy` on Line 36 to ensure that the images in the training and validation sets come from all classes uniformly.
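Note that `np.random.shuffle` gives a different ordering on every run, so the resulting split changes each time the script is executed. If you want a reproducible split, you can seed NumPy's random number generator before shuffling — a minimal sketch (the seed value here is an arbitrary choice, not something used in this tutorial):

```python
# seed the pseudo-random number generator so the shuffle (and thus
# the train/val split) is reproducible across runs
np.random.seed(42)
np.random.shuffle(imagePaths)
```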
Now, we define the fraction of total images we want to set aside as validation data. This is defined by `config.VAL_SPLIT`. A common choice is to hold out 10-20% of your data for validation. On Line 39, we take the `config.VAL_SPLIT` fraction of the total number of image paths as the validation set length (i.e., `valPathsLen`). This is cast to an integer since the number of images must be a whole number. For example, with 1,000 images and `VAL_SPLIT = 0.1`, we would get 100 validation images and 900 training images. The remaining paths are used as the training set length (i.e., `trainPathsLen`) on Line 40.
On Lines 41 and 42, we grab the training paths and validation paths from the `imagePaths` list. We then pass these lists to the `copy_images` function (described previously), which copies them to the `train` and `val` destination folders defined by `config.TRAIN` and `config.VAL`, respectively, as shown on Lines 47 and 48.
This structures our file system as shown below. Here, we have separate `train` and `val` folders, each containing the training and validation images from the different classes in their respective class folders.
```
├── train
│   ├── daisy
│   ├── dandelion
│   ├── roses
│   ├── sunflowers
│   └── tulips
└── val
    ├── daisy
    ├── dandelion
    ├── roses
    ├── sunflowers
    └── tulips
```
PyTorch Dataset and DataLoaders
Now that we have divided our dataset into training and validation sets, we are ready to use PyTorch Datasets and DataLoaders to set up our data loading pipeline.
A PyTorch Dataset provides functionalities to load and store our data samples with the corresponding labels. In addition to this, PyTorch also has an in-built `DataLoader` class which wraps an iterable around the dataset, enabling us to easily access and iterate over the data samples in our dataset.
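To make the `Dataset` abstraction concrete, here is a minimal sketch of a custom map-style dataset (illustrative only, not part of this tutorial's code): every such dataset just needs to implement the `__len__` and `__getitem__` methods.

```python
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    # a hypothetical dataset that wraps in-memory lists of
    # samples and their labels
    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __len__(self):
        # report the number of data points in the dataset
        return len(self.samples)

    def __getitem__(self, idx):
        # return a single (sample, label) pair by index
        return self.samples[idx], self.labels[idx]
```

We will not need to write such a class ourselves in this tutorial: the `ImageFolder` class we use below implements this interface for us.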
Let’s dive a little deeper and understand DataLoaders with the help of code. Basically, our goal is to load our training and validation sets with the help of the PyTorch `Dataset` class and access the samples with the help of the `DataLoader` class.
Open the `load_and_visualize.py` file in your project directory.
We start by importing the required packages.
```python
# USAGE
# python load_and_visualize.py

# import necessary packages
from pyimagesearch import config
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
```
Our notable imports (Lines 6-9) include:
- `ImageFolder` class: responsible for loading images from the `train` and `val` folders into a PyTorch dataset
- `DataLoader` class: enables us to wrap an iterable around our dataset so that data samples can be efficiently accessed by our deep learning model
- `transforms`: an in-built PyTorch module that provides common image transformations
- `matplotlib.pyplot`: for plotting and visualizing images
Now, we define the `visualize_batch` function, which will later enable us to plot and visualize sample images from the training and validation batches.
```python
def visualize_batch(batch, classes, dataset_type):
    # initialize a figure
    fig = plt.figure("{} batch".format(dataset_type),
        figsize=(config.BATCH_SIZE, config.BATCH_SIZE))

    # loop over the batch size
    for i in range(0, config.BATCH_SIZE):
        # create a subplot
        ax = plt.subplot(2, 4, i + 1)

        # grab the image, convert it from channels first ordering to
        # channels last ordering, and scale the raw pixel intensities
        # to the range [0, 255]
        image = batch[0][i].cpu().numpy()
        image = image.transpose((1, 2, 0))
        image = (image * 255.0).astype("uint8")

        # grab the label id and get the label from the classes list
        idx = batch[1][i]
        label = classes[idx]

        # show the image along with the label
        plt.imshow(image)
        plt.title(label)
        plt.axis("off")

    # show the plot
    plt.tight_layout()
    plt.show()
```
Starting on Line 12, the `visualize_batch` function takes as input a batch of data samples (i.e., `batch`), a list of class labels (i.e., `classes`), and the `dataset_type` to which the batch belongs (i.e., training or validation).
The function loops over the indices in a given batch:
- Grabs the image at the i-th index (Line 25)
- Converts it to channels-last format and scales it to the conventional image pixel range of `[0, 255]` (Lines 26 and 27)
- Gets the integer label of the i-th sample and maps it to the original class label of the flower dataset with the help of the `classes` list (Lines 30 and 31)
- Displays the image with its label (Lines 34 and 35)
While training deep models, we usually want to use data augmentation techniques on the images of our training set to improve the generalization ability of our model. PyTorch provides common image transformations that can be used out of the box with the help of the transforms module. We will now look at how this works and how to incorporate these transformations into our data loading pipeline.
```python
# initialize our data augmentation functions
resize = transforms.Resize(size=(config.INPUT_HEIGHT,
    config.INPUT_WIDTH))
hFlip = transforms.RandomHorizontalFlip(p=0.25)
vFlip = transforms.RandomVerticalFlip(p=0.25)
rotate = transforms.RandomRotation(degrees=15)
```
On Lines 43-47, we define four image transformations that we want to apply to our images:
- `resize`: enables us to resize our images to a particular input dimension (i.e., `config.INPUT_HEIGHT` × `config.INPUT_WIDTH`) that our deep model can accept
- `hFlip`, `vFlip`: allow us to horizontally/vertically flip our images. Note that they take an argument `p`, which defines the probability that the transform is applied to an input image.
- `rotate`: enables us to rotate our images by a random angle within a given range (here, ±15 degrees)
Note that PyTorch also provides many other image transformations apart from those mentioned above.
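For instance, color jittering, random cropping, and per-channel normalization are all available out of the box. Here is a quick sketch of a few of them (the parameter values below are arbitrary examples, not values used in this tutorial):

```python
from torchvision import transforms

# randomly perturb the brightness and contrast of the input image
jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2)

# randomly crop a 100x100 patch from the input image
crop = transforms.RandomCrop(size=(100, 100))

# normalize a tensor image with per-channel mean and std; this is
# applied after ToTensor() (the stats here are the widely used
# ImageNet statistics)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])
```

With that noted, we now chain our chosen transforms into training and validation pipelines: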
```python
# initialize our training and validation set data augmentation
# pipeline
trainTransforms = transforms.Compose([resize, hFlip, vFlip, rotate,
    transforms.ToTensor()])
valTransforms = transforms.Compose([resize, transforms.ToTensor()])
```
We consolidate the transforms with the help of the `Compose` method so that all of them can be applied to our input images one by one. Note that we have an additional `ToTensor()` transform here, which simply converts all input images to PyTorch tensors. In addition, this transform also scales the input PIL Image or `numpy.ndarray`, originally in the range `[0, 255]`, to the range `[0, 1]`.

Here, we define separate transforms for our training and validation sets, as shown on Lines 51-53. This is because we usually do not use data augmentations on our validation or test sets, except for transformations like `resize` and `ToTensor()`, which are necessary to convert the input data to a format our deep model can accept.
Now that we have set up the transformations to be applied, we are ready to load our images into a dataset.

PyTorch provides an in-built `ImageFolder` functionality that accepts a root folder and automatically grabs data samples from the given root directory to create a dataset. Note that `ImageFolder` expects images to be arranged in the following format:
```
root/class_name_1/img_id.png
root/class_name_2/img_id.png
root/class_name_3/img_id.png
root/class_name_4/img_id.png
```
This allows it to identify all the unique class names and map them to integer class labels. In addition, `ImageFolder` also accepts the transforms (as discussed before) that we want to apply to our input images while loading them.
```python
# initialize the training and validation dataset
print("[INFO] loading the training and validation dataset...")
trainDataset = ImageFolder(root=config.TRAIN,
    transform=trainTransforms)
valDataset = ImageFolder(root=config.VAL,
    transform=valTransforms)
print("[INFO] training dataset contains {} samples...".format(
    len(trainDataset)))
print("[INFO] validation dataset contains {} samples...".format(
    len(valDataset)))
```
On Lines 57-60, we use `ImageFolder` to create PyTorch Datasets for the training and validation sets, respectively. Note that each PyTorch Dataset has a `__len__` method that enables us to get the number of samples in the dataset, as shown on Lines 61-64.
Also, each Dataset has a `__getitem__` method that enables us to directly index into the samples and grab a particular data point. Suppose we want to check the i-th data sample in the dataset. We can simply index into our dataset as `trainDataset[i]` and access the data point, which is a tuple. This is because each data sample in our dataset is a tuple of the form `(image, label)`.
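For example, a quick sanity check along these lines (hypothetical code, not part of the tutorial scripts) confirms this:

```python
# grab the first data point: each item is an (image, label) tuple
image, label = trainDataset[0]
print(image.shape)                # e.g., torch.Size([3, 128, 128])
print(label)                      # an integer class id, e.g., 0
print(trainDataset.class_to_idx)  # e.g., {'daisy': 0, 'dandelion': 1, ...}
```

The `class_to_idx` attribute exposes the class-name-to-integer mapping that `ImageFolder` builds from the folder names.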
Now, we are ready to create a DataLoader for our datasets.
```python
# create training and validation set dataloaders
print("[INFO] creating training and validation set dataloaders...")
trainDataLoader = DataLoader(trainDataset,
    batch_size=config.BATCH_SIZE, shuffle=True)
valDataLoader = DataLoader(valDataset, batch_size=config.BATCH_SIZE)
```
A DataLoader accepts a PyTorch dataset and outputs an iterable which enables easy access to data samples from the dataset.
On Lines 68-70, we pass our training and validation datasets to the `DataLoader` class.

A PyTorch DataLoader accepts a `batch_size` argument so that it can divide the dataset into chunks of samples.

The samples in each chunk, or batch, can then be processed in parallel by our deep model. Furthermore, we can also decide whether we want to shuffle our samples before passing them to the deep model, which is usually required for optimal learning and convergence of batch-gradient-based optimization approaches.
Lines 68-70 return two iterables (i.e., `trainDataLoader` and `valDataLoader`).
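In a full training script, you would typically iterate over `trainDataLoader` directly, one batch at a time. A rough sketch of such a loop is shown below; the `model`, `criterion`, and `optimizer` objects are placeholders and are not defined anywhere in this tutorial:

```python
# hypothetical training loop: each iteration yields one batch of up to
# BATCH_SIZE images and their integer labels (the final batch may be
# smaller than BATCH_SIZE)
for images, labels in trainDataLoader:
    # images: [BATCH_SIZE, 3, INPUT_HEIGHT, INPUT_WIDTH] tensor
    # labels: [BATCH_SIZE] tensor of class ids
    outputs = model(images)            # forward pass (placeholder model)
    loss = criterion(outputs, labels)  # compute the loss
    optimizer.zero_grad()              # reset accumulated gradients
    loss.backward()                    # backpropagate
    optimizer.step()                   # update the model parameters
```

In this tutorial, though, we simply grab a single batch from each DataLoader to visualize its samples: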
```python
# grab a batch from both training and validation dataloader
trainBatch = next(iter(trainDataLoader))
valBatch = next(iter(valDataLoader))

# visualize the training and validation set batches
print("[INFO] visualizing training and validation batch...")
visualize_batch(trainBatch, trainDataset.classes, "train")
visualize_batch(valBatch, valDataset.classes, "val")
```
We convert the `trainDataLoader` and `valDataLoader` iterables to Python iterators using `iter()`, as shown on Lines 73 and 74. This allows us to simply grab a batch of training or validation samples with the help of `next()`.
Finally, Lines 78 and 79 visualize the `trainBatch` and `valBatch` with the help of the `visualize_batch` function. Figures 3 and 4 show sample images from the training and validation batches, respectively.
Built-in Datasets
In the previous sections of this PyTorch Data Loader tutorial, we learned to download a custom dataset, structure it, load it as a PyTorch dataset and access its samples with the help of DataLoaders.
In addition to this, PyTorch also provides a simple API that can be used to directly download and load images from some commonly used datasets in computer vision. These include datasets like MNIST, CIFAR-10, CIFAR-100, CelebA, etc.
We will now see how we can easily access these datasets and use them for our own projects. For the purpose of this tutorial, we use the MNIST dataset.
Let’s first open the `builtin_dataset.py` file in our project directory.
```python
# USAGE
# python builtin_dataset.py

# import necessary packages
from pyimagesearch import config
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
```
We start by importing the necessary packages on Lines 5-9. The `torchvision.datasets` module provides us with the functionality to directly load common, widely used datasets.

On Line 6, we import the `MNIST` dataset from this module. Our other notable imports include the PyTorch `DataLoader` class (Line 7), the transforms module from torchvision (Line 8), and the matplotlib library (Line 9) for visualization.
```python
def visualize_batch(batch, classes, dataset_type):
    # initialize a figure
    fig = plt.figure("{} batch".format(dataset_type),
        figsize=(config.BATCH_SIZE, config.BATCH_SIZE))

    # loop over the batch size
    for i in range(0, config.BATCH_SIZE):
        # create a subplot
        ax = plt.subplot(2, 4, i + 1)

        # grab the image, convert it from channels first ordering to
        # channels last ordering, and scale the raw pixel intensities
        # to the range [0, 255]
        image = batch[0][i].cpu().numpy()
        image = image.transpose((1, 2, 0))
        image = (image * 255.0).astype("uint8")

        # grab the label id and get the label from the classes list
        idx = batch[1][i]
        label = classes[idx]

        # show the image along with the label
        plt.imshow(image[..., 0], cmap="gray")
        plt.title(label)
        plt.axis("off")

    # show the plot
    plt.tight_layout()
    plt.show()
```
Next, we define the `visualize_batch` function, which helps us visualize samples from a batch. This function is similar to the `visualize_batch` function we defined earlier in the `load_and_visualize.py` file. The only difference is on Line 33, where we plot the image with `cmap="gray"` since MNIST consists of single-channel grayscale images, in contrast to the 3-channel RGB images in the flower dataset.
```python
# define the transform
transform = transforms.Compose([transforms.ToTensor()])

# initialize the training and validation dataset
print("[INFO] loading the training and validation dataset...")
trainDataset = MNIST(root=config.MNIST_DATASET_PATH, train=True,
    download=True, transform=transform)
valDataset = MNIST(root=config.MNIST_DATASET_PATH, train=False,
    download=True, transform=transform)
```
We define our transform on Line 42. On Lines 46-49, we use the `torchvision.datasets` module to directly download the MNIST training and validation sets and load them as the PyTorch datasets `trainDataset` and `valDataset`. Here, we need to provide the following arguments:
- `root`: the root directory where we want to save the dataset
- `train`: indicates whether we want to load the training set (`train=True`) or the test set (`train=False`)
- `download`: responsible for automatically downloading the dataset if it is not already present at `root`
- `transform`: the image transformations to be applied to the input images
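Because the built-in datasets in `torchvision.datasets` share this interface, swapping in a different dataset is nearly a one-line change. For example, here is a sketch for CIFAR-10 (the `"cifar10"` root path below is an arbitrary choice, not a value defined in our `config.py`):

```python
from torchvision.datasets import CIFAR10

# CIFAR-10 accepts the same root/train/download/transform arguments
trainDataset = CIFAR10(root="cifar10", train=True, download=True,
    transform=transform)
```

Returning to MNIST, we now create the DataLoaders for the training and validation sets and visualize a batch from each: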
```python
# create training and validation set dataloaders
print("[INFO] creating training and validation set dataloaders...")
trainDataLoader = DataLoader(trainDataset,
    batch_size=config.BATCH_SIZE, shuffle=True)
valDataLoader = DataLoader(valDataset, batch_size=config.BATCH_SIZE)

# grab a batch from both training and validation dataloader
trainBatch = next(iter(trainDataLoader))
valBatch = next(iter(valDataLoader))

# visualize the training set batch
print("[INFO] visualizing training batch...")
visualize_batch(trainBatch, trainDataset.classes, "train")

# visualize the validation set batch
print("[INFO] visualizing validation batch...")
visualize_batch(valBatch, valDataset.classes, "val")
```
We create DataLoaders for the training and validation datasets (i.e., `trainDataLoader` and `valDataLoader`) on Lines 53-55. Next, we grab batches from the training and validation DataLoaders and visualize sample images on Lines 58-67, as explained earlier.
Now you understand how to use PyTorch DataLoaders with built-in PyTorch datasets.
What's next? I recommend PyImageSearch University.
30+ total classes • 39h 44m video • Last updated: 12/2021
★★★★★ 4.84 (128 Ratings) • 3,000+ Students Enrolled
I strongly believe that if you had the right teacher you could master computer vision and deep learning.
Do you think learning computer vision and deep learning has to be time-consuming, overwhelming, and complicated? Or has to involve complex mathematics and equations? Or requires a degree in computer science?
That’s not the case.
All you need to master computer vision and deep learning is for someone to explain things to you in simple, intuitive terms. And that’s exactly what I do. My mission is to change education and how complex Artificial Intelligence topics are taught.
If you're serious about learning computer vision, your next stop should be PyImageSearch University, the most comprehensive computer vision, deep learning, and OpenCV course online today. Here you’ll learn how to successfully and confidently apply computer vision to your work, research, and projects. Join me in computer vision mastery.
Inside PyImageSearch University you'll find:
- ✓ 30+ courses on essential computer vision, deep learning, and OpenCV topics
- ✓ 30+ Certificates of Completion
- ✓ 39h 44m on-demand video
- ✓ Brand new courses released every month, ensuring you can keep up with state-of-the-art techniques
- ✓ Pre-configured Jupyter Notebooks in Google Colab
- ✓ Run all code examples in your web browser — works on Windows, macOS, and Linux (no dev environment configuration required!)
- ✓ Access to centralized code repos for all 500+ tutorials on PyImageSearch
- ✓ Easy one-click downloads for code, datasets, pre-trained models, etc.
- ✓ Access on mobile, laptop, desktop, etc.
Summary
In this tutorial, we learned how to structure our data loading pipeline with the help of built-in PyTorch functionalities. Specifically, we learned how PyTorch Datasets and DataLoaders could efficiently load and access data samples from our dataset.
Our goal was to structure the flowers dataset by splitting it into training and validation sets and to load the data samples with the help of PyTorch Datasets.
We discussed how to efficiently access the data samples using PyTorch DataLoaders while applying various data augmentations.
Finally, we also learned about some common built-in PyTorch datasets that can be directly loaded and used in our deep learning projects.
After following the tutorial, we built a data loading pipeline that can be seamlessly integrated with and used to train any deep learning model at hand. Congratulations!
Citation Information
Chandhok, S. “Image Data Loaders in PyTorch,” PyImageSearch, 2021, https://hcl.pyimagesearch.com/2021/10/04/image-data-loaders-in-pytorch/
@article{Chandhok_2021,
author = {Shivam Chandhok},
title = {Image Data Loaders in {PyTorch}},
journal = {PyImageSearch},
year = {2021},
note = {https://hcl.pyimagesearch.com/2021/10/04/image-data-loaders-in-pytorch/}
}
To download the source code to this post (and be notified when future tutorials are published here on PyImageSearch), simply enter your email address in the form below!
Download the Source Code and FREE 17-page Resource Guide
Enter your email address below to get a .zip of the code and a FREE 17-page Resource Guide on Computer Vision, OpenCV, and Deep Learning. Inside you'll find my hand-picked tutorials, books, courses, and libraries to help you master CV and DL!
Comment section
Hey, Adrian Rosebrock here, author and creator of PyImageSearch. While I love hearing from readers, a couple years ago I made the tough decision to no longer offer 1:1 help over blog post comments.
At the time I was receiving 200+ emails per day and another 100+ blog post comments. I simply did not have the time to moderate and respond to them all, and the sheer volume of requests was taking a toll on me.
Instead, my goal is to do the most good for the computer vision, deep learning, and OpenCV community at large by focusing my time on authoring high-quality blog posts, tutorials, and books/courses.
If you need help learning computer vision and deep learning, I suggest you refer to my full catalog of books and courses — they have helped tens of thousands of developers, students, and researchers just like yourself learn Computer Vision, Deep Learning, and OpenCV.
Click here to browse my full catalog.