How We Trained a Neural Network to Generate Shadows in a Photo: Part 2
In this article, we prepare for training and look at loss functions and metrics.
In this series, Artem Nazarenko, Computer Vision Engineer at Everypixel, shows how you can implement the architecture of a neural network. In the first part, we discussed the working principles of GANs and methods of collecting datasets for training. This part covers preparing for GAN training.
Loss Functions and Metrics
Attention. At this point, we deviate from the reference article and take a loss function suited to the segmentation problem: generating attention maps (masks) can be treated as a classic image segmentation task. We use Dice Loss as the loss function; it is robust to unbalanced data.
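As a reference point, here is a minimal sketch of Dice Loss in PyTorch (the exact implementation the article relies on may differ; in practice you would typically import a ready-made version rather than write your own):

```python
import torch

def dice_loss(pred, target, eps=1.0):
    # pred: predicted probabilities after sigmoid, target: binary mask,
    # both of shape (N, 1, H, W)
    intersection = (pred * target).sum(dim=(1, 2, 3))
    union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
    # Dice coefficient with a smoothing term eps to avoid division by zero
    dice = (2 * intersection + eps) / (union + eps)
    # Loss is 1 minus the mean Dice coefficient over the batch
    return 1 - dice.mean()
```

Because the loss depends on the overlap ratio rather than per-pixel counts alone, a small foreground object still contributes a meaningful gradient, which is why Dice Loss copes well with unbalanced masks.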
We take IoU (Intersection over Union) as a metric.
Learn more about Dice Loss and IoU.
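A simple IoU implementation might look like this (a sketch, assuming the predictions are probabilities that we binarize with a fixed threshold):

```python
import torch

def iou_metric(pred, target, threshold=0.5, eps=1e-7):
    # Binarize predicted probabilities before comparing with the mask
    pred = (pred > threshold).float()
    intersection = (pred * target).sum(dim=(1, 2, 3))
    # Union = |A| + |B| - |A ∩ B|
    union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) - intersection
    return ((intersection + eps) / (union + eps)).mean()
```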
Shadow Generation. For the generation block, we take a loss function similar to the one given in the original article. It is a weighted sum of three loss functions: L2, Lper, and Ladv:
L2 estimates the distance from the ground truth image to the generated ones (before and after the refinement block, denoted as R).
Lper (perceptual loss) is a loss function that calculates the distance between feature maps of the VGG16 network when images are passed through it. The difference is computed as the standard MSE between the feature maps of the ground truth image with a shadow and those of the generated images, before and after the refinement block, respectively.
Ladv is a standard adversarial loss that takes into account the competitive nature of the generator and the discriminator. D (.) is the probability of belonging to the "real image" class. During training, the generator tries to minimize Ladv, while the discriminator, on the contrary, tries to maximize it.
Installing the required modules. To implement our ARShadowGAN-like network, we will use the Python deep learning library PyTorch.
Libraries in use. We start the work by installing the required modules:
- to import U-Net architecture,
- for augmentations,
- to import the required loss function,
- for rendering images inside Jupyter notebooks,
- to work with arrays,
- to work with images,
- to visualize training schedules,
- for neural networks and deep learning,
- to import models, for deep learning,
- for progress bar visualization.
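The package names did not survive extraction in the list above; a plausible set of installs matching these descriptions might look like this (the exact package choices are an assumption):

```shell
# Assumed packages matching the descriptions above (names are a guess):
pip install segmentation-models-pytorch  # U-Net architecture and losses
pip install albumentations               # augmentations
pip install matplotlib                   # image display and training curves
pip install numpy                        # arrays
pip install opencv-python                # image I/O
pip install torch torchvision            # deep learning and pretrained models
pip install tqdm                         # progress bars
```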
Dataset: structure, download, unpacking. For training and testing, we will use a ready-made dataset. The data is already split into train and test samples; we download and unpack it.
The folder structure in the dataset is as follows. Each of the samples contains five folders with the following images:
- noshadow (shadow-free images),
- shadow (images with shadows),
- mask (masks of inserted objects),
- robject (neighboring objects or occluders),
- rshadow (shadows from neighboring objects).
You can prepare your dataset with a similar file structure.
We prepare the ARDataset class for processing images and returning the i-th data sample on request. The main method of the class is __getitem__(); it returns the i-th image and the corresponding mask.
Declare augmentations and functions for data processing. We take augmentations from the albumentations library.
In the next and final part, we start training.