IMAGE AUGMENTATION FOR CREATING DATASETS USING PYTORCH FOR DUMMIES BY A DUMMY

Anush Somasundaram
9 min read · Mar 30, 2023

Most models need a good chunk of data to be trained, tested, and validated. This data isn’t always readily available to suit the needs of the model you’ve created. Data augmentation is a brilliant way to get more out of the data you already have available.

Recently, I needed to build an optical character recognition (OCR) model for an unusual script that hadn't been worked on very extensively before. That meant collecting handwritten samples, but I couldn't possibly get enough of them by asking friends for help or outsourcing the job.

In comes the magic of augmentation: instead of drawing every possible variation of a character in the script, you hand-draw a few of them and augment those. Image augmentation lets us get a variety of images from one single image.

PyTorch's torchvision package has a module called torchvision.transforms that lets us augment images in different ways, creating multiple images from a single image and, in turn, a larger dataset. First, we'll go through some basic augmentations using torchvision.transforms. At the end, we'll finish off with a script that walks through a hierarchy of image folders, augments every image, and stores the results in the same folder structure. The torchvision.transforms module makes augmentation super smooth and simple.

All this stuff can also be done with the OpenCV library, but we’re not going to use that.

Before we move on, make sure you have all these libraries installed, imported, and ready to go (I'm using an IPython notebook):

import PIL
import torch
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torchvision.transforms as T

We're obviously gonna need an image to work with; otherwise this entire article would be pointless. Since I'd like to think I have a little bit of joy in my life, I'm not gonna use a handwritten letter or something boring. Instead, we're going to use a picture of a cute golden retriever.

Let's open the image using PIL's Image module and store it in a variable called orig_img.

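# convert('RGB') drops any alpha channel (common in PNGs), so every transform below sees a 3-channel image
orig_img = PIL.Image.open(Path('dog.png')).convert('RGB')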

Now let's display the image we're going to be working with, using matplotlib.pyplot's plt.imshow(). Note that plt.imshow() takes the image itself (here orig_img), not a file path.

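np.asarray(orig_img).shape   # (height, width, channels)
plt.imshow(orig_img)         # draw the image inline in the notebook
orig_img.show()              # open it in your system's default image viewer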

Now let's get to messing around with (and dismembering) this image.

1. Resize the image

Sometimes the images are too big to process, so we resize them so that our models can be trained effectively. Keep in mind that this has to be done to all images, even after augmentation: most convolutional neural networks (in particular those ending in fully connected layers) expect every input to have the same size, and images stacked into one batch must share dimensions anyway. On the other hand, recurrent neural networks are something you can have a blast with: they process variable-length sequences, so input size is much less of a constraint.

resize_transform = T.Resize((32,32))
resized_image=resize_transform(orig_img)
plt.imshow(resized_image)
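Why does every image need the same size? A batch is a single array, so every image stacked into it must share its height and width. The NumPy sketch below illustrates this under stated assumptions: resize_nearest is a hypothetical helper using nearest-neighbour sampling, whereas T.Resize interpolates bilinearly by default. It resizes images of different shapes (as different augmentations might produce) to 32x32 and stacks them into one batch.

```python
import numpy as np


def resize_nearest(img: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize of an (H, W) or (H, W, C) array (hypothetical helper)."""
    h, w = img.shape[:2]
    # Map the centre of each output pixel back to the source pixel it falls in.
    rows = ((np.arange(out_h) + 0.5) * h / out_h).astype(int)
    cols = ((np.arange(out_w) + 0.5) * w / out_w).astype(int)
    return img[rows[:, None], cols[None, :]]


# Images of different sizes, e.g. outputs of different augmentations.
images = [np.zeros((600, 800, 3), np.uint8),
          np.zeros((560, 560, 3), np.uint8),
          np.zeros((120, 90, 3), np.uint8)]

# Resize everything to 32x32 as the last step, then stack into one batch.
batch = np.stack([resize_nearest(im, 32, 32) for im in images])
print(batch.shape)  # (3, 32, 32, 3)
```

Without the resize step, np.stack would fail because the shapes differ, which is the same reason a CNN training loop needs uniformly sized inputs.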

2. Converting the image to grayscale

Converting images to grayscale is commonly done before extracting descriptors (features) from them, and processing an image in grayscale is less computationally taxing than processing it in colour. OpenCV also has a super easy way of converting images to grayscale; I'd say it's about the same amount of effort, but that's for another article.

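# num_output_channels=3 keeps three identical channels, so the output has the same shape as the RGB input;
# use T.Grayscale(1) for a true single-channel image
grayscale_transform = T.Grayscale(3)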
grayscaled_image=grayscale_transform(orig_img)
plt.imshow(grayscaled_image)
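What the grayscale conversion computes is a weighted sum of the colour channels. Pillow's "L" mode conversion uses the ITU-R 601-2 luma weights (0.299 R + 0.587 G + 0.114 B), and T.Grayscale on PIL images should behave roughly the same way. The NumPy sketch below is an approximation, not torchvision's code: to_grayscale is a hypothetical helper, and its rounding may differ from Pillow's integer arithmetic by one level. It also shows what num_output_channels=3 does: the single channel is simply repeated.

```python
import numpy as np


def to_grayscale(rgb: np.ndarray, num_output_channels: int = 1) -> np.ndarray:
    """Convert an (H, W, 3) uint8 RGB array to grayscale with ITU-R 601-2 luma weights."""
    weights = np.array([0.299, 0.587, 0.114])
    gray = rgb[..., :3].astype(np.float64) @ weights   # weighted sum over R, G, B
    gray = np.clip(np.rint(gray), 0, 255).astype(np.uint8)
    if num_output_channels == 3:
        # Repeat the single channel so the result has the same shape as an RGB image.
        gray = np.repeat(gray[..., None], 3, axis=-1)
    return gray


img = np.zeros((2, 2, 3), np.uint8)
img[0, 0] = [255, 0, 0]          # one pure-red pixel
g1 = to_grayscale(img)
g3 = to_grayscale(img, 3)
print(g1.shape, g3.shape)        # (2, 2) (2, 2, 3)
print(g1[0, 0])                  # 76
```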

3. Rotating the image

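Image rotation is a common image processing routine: it gives the model rotated versions of the same features to extract and match, and it increases the size of the training data by showing the object from different perspectives. Note that T.RandomRotation(45) does not rotate by exactly 45°; on every call it picks a random angle between -45° and +45°. So the three transforms below rotate by random angles within ±45°, ±65° and ±85° respectively. You can also pass your own range, e.g. T.RandomRotation((10, 30)), or rotate by a fixed angle with torchvision.transforms.functional.rotate(img, 45).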

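# Each transform draws a new random angle on every call: within ±45°, ±85° and ±65° respectively
random_rotation_transformation_45 = T.RandomRotation(45)
random_rotation_transformation_85 = T.RandomRotation(85)
random_rotation_transformation_65 = T.RandomRotation(65)
# In a notebook, run each plt.imshow() in its own cell; calls in the same cell draw over each other
plt.imshow(random_rotation_transformation_45(orig_img))
plt.imshow(random_rotation_transformation_85(orig_img))
plt.imshow(random_rotation_transformation_65(orig_img))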
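Under the hood, rotating an image means working out, for every output pixel, which input pixel lands there. The NumPy sketch below is an illustration under assumptions, not torchvision's implementation: the helpers rotate_nearest and random_rotation are hypothetical, use nearest-neighbour sampling, and fill uncovered corners with 0. Positive angles rotate counter-clockwise, and random_rotation mimics how RandomRotation(degrees) draws its angle uniformly from [-degrees, +degrees].

```python
import math
import random

import numpy as np


def rotate_nearest(img: np.ndarray, angle_deg: float, fill=0) -> np.ndarray:
    """Rotate an (H, W) or (H, W, C) array counter-clockwise about its centre.

    Inverse mapping: for each output pixel, find the source pixel that rotates
    onto it and copy that value. Output pixels with no source get `fill`.
    """
    h, w = img.shape[:2]
    cy, cx = (h - 1) / 2, (w - 1) / 2
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    rows, cols = np.indices((h, w))
    dy, dx = rows - cy, cols - cx
    # Source coordinates, rounded to the nearest pixel.
    src_x = np.rint(cos_t * dx - sin_t * dy + cx).astype(int)
    src_y = np.rint(sin_t * dx + cos_t * dy + cy).astype(int)
    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.full_like(img, fill)
    out[inside] = img[src_y[inside], src_x[inside]]
    return out


def random_rotation(img: np.ndarray, degrees: float, rng: random.Random):
    """Draw an angle uniformly from [-degrees, +degrees], then rotate by it."""
    angle = rng.uniform(-degrees, degrees)
    return rotate_nearest(img, angle), angle


a = np.arange(16).reshape(4, 4)
print(np.array_equal(rotate_nearest(a, 90), np.rot90(a)))  # True
rotated, angle = random_rotation(a, 45, random.Random(0))
```

Sampling a fresh angle per call is what makes one RandomRotation object produce a different image every time it is applied.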

4. Random crop

In real-world situations, objects are not always fully visible in the image, or at the same scale as in our training data. So when we train our model, we add a little variety to the dataset by training it on parts of the images, so that it can make an accurate prediction even when the entire object is not visible (yeah, I know I can't call a dog an object, but this is not the time to be politically correct). Hence we crop a random region of the image and create a new image from the cropped portion.

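size_of_crop = 560
# Cuts out a random 560x560 region; the image must be at least that large,
# otherwise RandomCrop raises an error (unless you pass pad_if_needed=True)
random_crops = T.RandomCrop(size = size_of_crop)
required_image = random_crops(orig_img)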

plt.imshow(required_image)
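Random cropping boils down to picking a random top-left corner such that the whole crop still fits inside the image. The NumPy sketch below shows that idea; random_crop is a hypothetical helper written for illustration, not torchvision's code, and it returns the chosen corner so you can see where the crop came from.

```python
import numpy as np


def random_crop(img: np.ndarray, size: int, rng: np.random.Generator):
    """Cut a random (size x size) window out of an (H, W) or (H, W, C) array."""
    h, w = img.shape[:2]
    if size > h or size > w:
        raise ValueError(f"crop size {size} is larger than image size {(h, w)}")
    # Every valid top-left corner is equally likely.
    top = int(rng.integers(0, h - size + 1))
    left = int(rng.integers(0, w - size + 1))
    return img[top:top + size, left:left + size], (top, left)


rng = np.random.default_rng(0)
img = np.arange(600 * 800).reshape(600, 800)
crop, (top, left) = random_crop(img, 560, rng)
print(crop.shape)  # (560, 560)
```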

5. Gaussian Blur

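Gaussian blurring is used to remove noise and speckles from the image. It works by convolving the image with a Gaussian kernel, so each pixel becomes a weighted average of its neighbours, with weights that fall off with distance. This removes high-frequency components that could otherwise be learned as false features. We control the level of smoothing with the sigma value (the standard deviation of the Gaussian): a larger sigma blurs more. In T.GaussianBlur, kernel_size=(7, 13) means a kernel 7 pixels wide and 13 pixels tall, and passing sigma as a (min, max) pair makes the transform draw a new random sigma from that range on every call.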

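# kernel_size=(7, 13): a kernel 7 pixels wide and 13 pixels tall (both must be odd)
# sigma=(min, max): a new sigma is drawn uniformly from this range on every call
gausian_blur_transformation_13 = T.GaussianBlur(kernel_size = (7,13), sigma = (6 , 7))
gausian_blur_transformation_56 = T.GaussianBlur(kernel_size = (7,13), sigma = (2 , 9))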
gausian_blurred_image_13 = gausian_blur_transformation_13(orig_img)
gausian_blurred_image_56 = gausian_blur_transformation_56(orig_img)
plt.imshow(gausian_blurred_image_13)
plt.imshow(gausian_blurred_image_56)
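The blur itself can be sketched in a few lines of NumPy. Because a 2-D Gaussian is separable, filtering the rows with a 1-D Gaussian and then the columns gives the same result as one 2-D convolution. This is an illustration, not torchvision's implementation: gaussian_kernel_1d and gaussian_blur are hypothetical helpers that work on a single-channel float array and mirror the borders with reflect padding.

```python
import numpy as np


def gaussian_kernel_1d(size: int, sigma: float) -> np.ndarray:
    """Sampled 1-D Gaussian, normalised to sum to 1 (size should be odd)."""
    x = np.arange(size) - (size - 1) / 2
    k = np.exp(-(x ** 2) / (2 * sigma ** 2))
    return k / k.sum()


def gaussian_blur(img: np.ndarray, kx: int, ky: int, sigma: float) -> np.ndarray:
    """Blur a 2-D float array with a kernel kx wide and ky tall (separable filtering)."""
    pad = ((ky // 2, ky // 2), (kx // 2, kx // 2))
    padded = np.pad(img, pad, mode="reflect")          # mirror the borders
    row_k = gaussian_kernel_1d(kx, sigma)
    col_k = gaussian_kernel_1d(ky, sigma)
    # Filter every row, then every column; "valid" trims the padding back off.
    tmp = np.apply_along_axis(np.convolve, 1, padded, row_k, mode="valid")
    return np.apply_along_axis(np.convolve, 0, tmp, col_k, mode="valid")


rng = np.random.default_rng(0)
sigma = rng.uniform(6, 7)          # like sigma=(6, 7): a fresh sigma per call
impulse = np.zeros((41, 41))
impulse[20, 20] = 1.0              # a single bright pixel
blurred = gaussian_blur(impulse, kx=7, ky=13, sigma=sigma)
print(blurred.shape)               # (41, 41)
```

The single bright pixel gets spread over its neighbourhood while the total brightness is preserved, which is exactly why isolated speckles fade under a Gaussian blur.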

6. Gaussian Noise

Adding Gaussian noise to the image creates controlled variations in the training data, which helps the model cope with noisy real-world inputs. The noise is drawn from a Gaussian (normal) distribution. Since the noise is easiest to add to a tensor, we write a little wrapper function that converts the PIL image to a tensor, adds the noise, clips the values back to the valid [0, 1] range, and converts the result back to an image. We can change the level of noise by changing the noise_factor parameter passed to the function.

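def addnoise(input_image, noise_factor = 0.3):
    # ToTensor turns the PIL image into a float tensor with values in [0, 1]
    inputs = T.ToTensor()(input_image)
    # torch.randn_like draws zero-mean Gaussian noise; noise_factor sets its standard deviation
    noise = inputs + torch.randn_like(inputs) * noise_factor
    noise = torch.clip(noise, 0, 1.)      # keep pixel values in the valid range
    output_image = T.ToPILImage()         # convert back so it can be shown and saved like the other images
    image = output_image(noise)
    return image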

gausian_image_3 = addnoise(orig_img)
gausian_image_6 = addnoise(orig_img,0.6)
gausian_image_9 = addnoise(orig_img,0.9)
plt.imshow(gausian_image_3)
plt.imshow(gausian_image_6)
plt.imshow(gausian_image_9)
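The same idea in plain NumPy makes the statistics easy to check. The key detail is drawing from a normal distribution (standard_normal) rather than a uniform one: uniform noise in [0, 1) is never negative, so it would only ever brighten the image. add_gaussian_noise is a hypothetical helper for illustration; it assumes a float image already scaled to [0, 1].

```python
import numpy as np


def add_gaussian_noise(img: np.ndarray, noise_factor: float = 0.3, rng=None) -> np.ndarray:
    """Add zero-mean Gaussian noise with std = noise_factor to a float image in [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    noisy = img + rng.standard_normal(img.shape) * noise_factor
    return np.clip(noisy, 0.0, 1.0)   # keep values in the valid [0, 1] range


rng = np.random.default_rng(0)
gray = np.full((200, 200), 0.5)       # a flat mid-grey test image
noisy = add_gaussian_noise(gray, 0.05, rng)
# With little clipping, the mean stays near 0.5 and the std near noise_factor.
print(float(noisy.mean()), float(noisy.std()))
```

With larger factors such as 0.9, many values get clipped at 0 or 1, so the image turns into heavy salt-and-pepper-like speckle.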

7. Colour Jitter

The ColorJitter transform from torchvision.transforms randomly varies the brightness, contrast, saturation and hue of the image. Creating variations in these aspects helps the model learn features that don't depend on exact colours, so the object can be recognised in various surroundings and lighting conditions. Each argument defines a range that a random factor is drawn from on every call, so the same transform gives a different result each time. Three colour jitter transforms are applied to this image; you can make more by changing the values passed as parameters.

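# A single number v means a range: brightness/contrast/saturation factors come from [max(0, 1 - v), 1 + v]
# and hue from [-v, v]; a (min, max) pair is used as-is. Hue values must stay within [-0.5, 0.5].
colour_jitter_transformation_1 = T.ColorJitter(brightness=(0.5, 1.5), contrast=3, saturation=(0.3, 1.5), hue=(-0.1, 0.1))

colour_jitter_transformation_2 = T.ColorJitter(brightness=0.7, contrast=6, saturation=0.9, hue=(-0.1, 0.1))

colour_jitter_transformation_3 = T.ColorJitter(brightness=(0.5, 1.5), contrast=2, saturation=1.4, hue=(-0.1, 0.5))
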
colour_jitter_image_1 = colour_jitter_transformation_1(orig_img)
colour_jitter_image_2 = colour_jitter_transformation_2(orig_img)
colour_jitter_image_3 = colour_jitter_transformation_3(orig_img)
plt.imshow(colour_jitter_image_1)
plt.imshow(colour_jitter_image_2)
plt.imshow(colour_jitter_image_3)
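To see how a single number such as contrast=3 turns into a range, and what the brightness and contrast factors do to pixels, here is a NumPy sketch. It is an approximation under assumptions: jitter_range, adjust_brightness and adjust_contrast are hypothetical helpers, only brightness and contrast are covered, and the real ColorJitter also handles saturation and hue and applies the adjustments in a random order. Contrast is modelled as blending the image with its mean grey level.

```python
import numpy as np


def jitter_range(value, center=1.0, lower=0.0):
    """Turn a ColorJitter-style argument into a (min, max) range (hypothetical helper).

    A number v becomes [max(lower, center - v), center + v]; a pair is used as-is.
    """
    if isinstance(value, (int, float)):
        return (max(lower, center - value), center + value)
    low, high = value
    return (float(low), float(high))


def adjust_brightness(img: np.ndarray, factor: float) -> np.ndarray:
    """Scale every pixel of a float RGB image in [0, 1]; factor 0 gives black."""
    return np.clip(img * factor, 0.0, 1.0)


def adjust_contrast(img: np.ndarray, factor: float) -> np.ndarray:
    """Blend the image with its mean grey level; factor 0 gives a flat grey image."""
    mean = (img @ np.array([0.299, 0.587, 0.114])).mean()
    return np.clip(mean + factor * (img - mean), 0.0, 1.0)


print(jitter_range(3))                                # (0.0, 4.0)
print(jitter_range(0.1, center=0.0, lower=-0.5))      # (-0.1, 0.1), the hue case

rng = np.random.default_rng(0)
img = rng.random((4, 4, 3))
b = rng.uniform(*jitter_range((0.5, 1.5)))            # brightness=(0.5, 1.5)
c = rng.uniform(*jitter_range(3))                     # contrast=3 -> [0, 4]
out = adjust_contrast(adjust_brightness(img, b), c)
```

Note how contrast=3 allows a factor anywhere from 0 (a completely grey image) up to 4, which is why strong jitter settings can produce some very washed-out or harsh samples.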

8. Random Invert

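T.RandomInvert(p) inverts the colours of the given image with probability p, which adds more variety to the dataset. With p = 0.25, the transform returns the unchanged image about three times out of four, so you may need to run the cell a few times to see an inverted dog.

transform = T.RandomInvert(p = 0.25)   # inverts roughly one call in four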
inverted_img = transform(orig_img)
plt.imshow(inverted_img)
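For an 8-bit image, inverting means replacing every value x with 255 - x, and the "random" part is just a coin flip with probability p. The NumPy sketch below shows the idea; random_invert is a hypothetical helper, not torchvision's implementation.

```python
import numpy as np


def random_invert(img: np.ndarray, p: float, rng: np.random.Generator) -> np.ndarray:
    """With probability p return 255 - img (uint8); otherwise return img unchanged."""
    if rng.random() < p:
        return 255 - img
    return img


rng = np.random.default_rng(0)
img = np.array([[0, 100, 255]], dtype=np.uint8)
print(random_invert(img, 1.0, rng))   # [[255 155   0]]

# With p=0.25, roughly a quarter of calls invert the image.
flips = sum(random_invert(img, 0.25, rng)[0, 0] == 255 for _ in range(10_000))
```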

Script for efficient augmentation of classified images

Now we've seen all of these transformations individually, but it wouldn't make sense to apply each of them by hand to every single image. Unless you like pain, it's not worth it.

The script below walks through the folders of training images, creates all the augmentations of each image, and stores them in another folder with the same folder structure. Make sure you have a folder of folders containing the training images, where each sub-folder is named after the training label of the images inside it. Create a new folder for the augmented images. Set the variable master_dataset to the path of the folder with the images to be augmented, and the variable augmented_dataset to the path of the folder where the augmented images will be stored. Then run the script and, et voilà, you have all the augmented images in the new folder.

Before running the script, the specified folder for augmented images is empty.

The script:

# This script creates augmented images from each image to build a larger dataset for our CNN model
# The augmentations this script performs on each image are:
# orig_img,grayscaled_image,random_rotation_transformation_45_image,random_rotation_transformation_65_image,random_rotation_transformation_85_image,gausian_blurred_image_13_image,gausian_blurred_image_56_image,gausian_image_3,gausian_image_6,gausian_image_9,colour_jitter_image_1,colour_jitter_image_2,colour_jitter_image_3

# Call creating_file_with_augmented_images() with the path of the dataset and the path of the folder where you want the augmented images to be stored

import os
from pathlib import Path

import torch
import torchvision.transforms as T
from PIL import Image

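# torchvision.transforms objects, created once and reused for every image

# grayscale (3 identical channels, so the output keeps the RGB shape)
grayscale_transform = T.Grayscale(3)

# random rotation: each call picks a random angle within ±45°, ±85° and ±65° respectively
random_rotation_transformation_45 = T.RandomRotation(45)
random_rotation_transformation_85 = T.RandomRotation(85)
random_rotation_transformation_65 = T.RandomRotation(65)

# Gaussian blur: 7x13 kernel, sigma drawn from the given range on each call
gausian_blur_transformation_13 = T.GaussianBlur(kernel_size = (7,13), sigma = (6 , 9))
gausian_blur_transformation_56 = T.GaussianBlur(kernel_size = (7,13), sigma = (5 , 8))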

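# Gaussian noise

def addnoise(input_image, noise_factor = 0.3):
    # Convert the PIL image to a float tensor with values in [0, 1]
    inputs = T.ToTensor()(input_image)
    # torch.randn_like draws zero-mean Gaussian noise; noise_factor sets its standard deviation
    noisy = inputs + torch.randn_like(inputs) * noise_factor
    noisy = torch.clip(noisy, 0, 1.)      # keep pixel values in the valid range
    output_image = T.ToPILImage()
    image = output_image(noisy)
    return image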

# Colour jitter: each call draws random brightness, contrast, saturation and hue factors from these ranges

colour_jitter_transformation_1 = T.ColorJitter(brightness=(0.5, 1.5), contrast=3, saturation=(0.3, 1.5), hue=(-0.1, 0.1))

colour_jitter_transformation_2 = T.ColorJitter(brightness=0.7, contrast=6, saturation=0.9, hue=(-0.1, 0.1))

colour_jitter_transformation_3 = T.ColorJitter(brightness=(0.5, 1.5), contrast=2, saturation=1.4, hue=(-0.1, 0.5))

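# Random invert (p=0.5 by default); note that augment_image() below does not use it,
# add it to the returned list if you want inverted copies too

random_invert_transform = T.RandomInvert()

# Main function: applies the transforms above to one image and returns the original plus 12 augmented versions (13 images)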

def augment_image(img_path):

    # original image, converted to RGB so PNGs with an alpha channel and grayscale scans behave like colour images
    orig_img = Image.open(Path(img_path)).convert("RGB")

    # grayscale
    grayscaled_image = grayscale_transform(orig_img)

    # random rotation (each call picks a new random angle)
    random_rotation_transformation_45_image = random_rotation_transformation_45(orig_img)
    random_rotation_transformation_85_image = random_rotation_transformation_85(orig_img)
    random_rotation_transformation_65_image = random_rotation_transformation_65(orig_img)

    # Gaussian blur
    gausian_blurred_image_13_image = gausian_blur_transformation_13(orig_img)
    gausian_blurred_image_56_image = gausian_blur_transformation_56(orig_img)

    # Gaussian noise at three strengths
    gausian_image_3 = addnoise(orig_img)
    gausian_image_6 = addnoise(orig_img, 0.6)
    gausian_image_9 = addnoise(orig_img, 0.9)

    # colour jitter
    colour_jitter_image_1 = colour_jitter_transformation_1(orig_img)
    colour_jitter_image_2 = colour_jitter_transformation_2(orig_img)
    colour_jitter_image_3 = colour_jitter_transformation_3(orig_img)

    # To preview any result while debugging, call .show() on it, e.g. grayscaled_image.show()
    return [orig_img, grayscaled_image,
            random_rotation_transformation_45_image, random_rotation_transformation_65_image, random_rotation_transformation_85_image,
            gausian_blurred_image_13_image, gausian_blurred_image_56_image,
            gausian_image_3, gausian_image_6, gausian_image_9,
            colour_jitter_image_1, colour_jitter_image_2, colour_jitter_image_3]

#augmented_images = augment_image(orig_img_path)

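# Assumption: add any other image extensions your dataset uses
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}

def creating_file_with_augmented_images(file_path_master_dataset, file_path_augmented_images):

    master_dataset_folder = file_path_master_dataset
    augmented_images_folder = file_path_augmented_images

    # Only sub-folders are class labels; skip stray files such as macOS's .DS_Store
    label_folders = sorted(
        element for element in os.listdir(master_dataset_folder)
        if os.path.isdir(os.path.join(master_dataset_folder, element))
    )

    for counter, element in enumerate(label_folders, start=1):
        # Mirror the label folder; exist_ok=True lets the script be re-run without crashing
        os.makedirs(os.path.join(augmented_images_folder, element), exist_ok=True)
        images_in_folder = sorted(
            image for image in os.listdir(os.path.join(master_dataset_folder, element))
            if Path(image).suffix.lower() in IMAGE_SUFFIXES
        )
        for counter2, image in enumerate(images_in_folder, start=1):
            required_images = augment_image(os.path.join(master_dataset_folder, element, image))
            for counter3, augmented_image in enumerate(required_images, start=1):
                # File name: <folder no.>_<image no.>_<augmentation no.>_<original name>
                # save() writes the file and returns None, so its result is not assigned
                augmented_image.save(os.path.join(augmented_images_folder, element, f"{counter}_{counter2}_{counter3}_{image}"))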

# Quick check on a single image:
# images = augment_image("dog.png")
# for element in images:
#     element.show()

# augmented dataset path (change this to your own output folder)
augmented_dataset = "/Users/software/Desktop/sem_6/Hieroglyphics_nlp/Code_image_augmentation/augmented_images_dataset"

# master dataset path (change this to your own dataset folder)
master_dataset = "/Users/software/Desktop/sem_6/Hieroglyphics_nlp/Code_image_augmentation/Master_dataset"

# run the program

creating_file_with_augmented_images(master_dataset, augmented_dataset)

After running the script, all the augmented images are stored in the specified folder with the same file structure.
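If you want to check the folder mirroring before processing any real images, a dry run that only lists the output paths is cheap. The sketch below is a hypothetical helper (plan_outputs and the IMAGE_SUFFIXES set are assumptions, not part of the script) that follows the script's {counter}_{counter2}_{counter3}_{image} naming scheme, with 13 outputs per image (the original plus 12 augmentations). It sorts folder contents, so its numbering may differ from a run that uses unsorted os.listdir order.

```python
import tempfile
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp"}  # assumption: adjust to your data


def plan_outputs(master: Path, augmented: Path, n_versions: int) -> list:
    """List the paths an augmentation run would write, without touching any image.

    master/<label>/<image> maps to augmented/<label>/<i>_<j>_<k>_<image>.
    """
    planned = []
    label_dirs = sorted(d for d in master.iterdir() if d.is_dir())  # skips stray files
    for i, label_dir in enumerate(label_dirs, start=1):
        images = sorted(f for f in label_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES)
        for j, image in enumerate(images, start=1):
            for k in range(1, n_versions + 1):
                planned.append(augmented / label_dir.name / f"{i}_{j}_{k}_{image.name}")
    return planned


with tempfile.TemporaryDirectory() as tmp:
    master = Path(tmp) / "Master_dataset"
    for label in ("alpha", "beta"):
        (master / label).mkdir(parents=True)
        (master / label / "sample.png").touch()
    (master / ".DS_Store").touch()  # stray file that should be ignored
    plan = plan_outputs(master, Path(tmp) / "augmented", n_versions=13)
    print(len(plan))                 # 26
    print(plan[0].relative_to(tmp))  # augmented/alpha/1_1_1_sample.png (POSIX separators)
```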

I hope the script worked on your machine. If you're having trouble, go through the comments in the script and make sure you have all the libraries it uses installed and ready to go.

Happy augmenting…… (cringes in silence).
