How We Trained a Neural Network to Generate Shadows in a Photo: Part 3

In this article, we train a neural network to generate shadows in photos.

By Artyom Nazarenko · Feb. 23, 21 · Tutorial

In this series, Artem Nazarenko, Computer Vision Engineer at Everypixel, shows you how to implement the architecture of a neural network. The first part covered the working principles of GANs and methods of collecting datasets for training, and the second part covered preparing for GAN training. Today, we are going to start training.

Training

We declare the datasets and dataloaders that load the data, and then specify the device on which the network will be trained.

```python
import os.path as osp

from torch.utils.data import DataLoader

# ARDataset, get_training_augmentation, get_validation_augmentation and
# get_preprocessing come from the previous parts of this series.

# The number of images that pass through the neural network at one time
batch_size = 8
dataset_path = '/path/to/your/dataset'
train_path = osp.join(dataset_path, 'train')
test_path = osp.join(dataset_path, 'test')

# Declare datasets
train_dataset = ARDataset(train_path,
                          augmentation=get_training_augmentation(),
                          preprocessing=get_preprocessing())
valid_dataset = ARDataset(test_path,
                          augmentation=get_validation_augmentation(),
                          preprocessing=get_preprocessing())

# Declare dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
```



Specify the device on which the network will be trained:

```python
import torch

# Use the first GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
```



We train the attention block and the shadow generation (SG) block separately.

Attention block training. We use U-Net as the attention block model and import the architecture from the segmentation_models.pytorch repository. To improve the quality of the network, we replace U-Net's standard encoder with the resnet34 classification network.

Since the attention block takes a shadow-free image and a mask of the inserted object as input, we replace the first convolutional layer of the model: the module receives a 4-channel tensor (3 color channels + 1 black-and-white mask channel).
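To see why only the channel count of the first layer has to change, here is a small NumPy sketch of the shape bookkeeping. It stacks a 3-channel image and a 1-channel mask along the channel axis (as `torch.cat(..., dim=1)` does in the training code) and applies the standard convolution output-size formula to the 7x7, stride-2, padding-3 layer. The 256x256 size is an assumption borrowed from the discriminator's input shape used later in the article, and `conv_out_size` and `attention_input` are hypothetical helpers written for illustration.

```python
import numpy as np


def conv_out_size(size, kernel, stride, padding):
    # Standard output-size formula for a 2D convolution (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1


def attention_input(image, object_mask):
    # image: (B, 3, H, W) shadow-free image; object_mask: (B, 1, H, W)
    # Concatenate along the channel axis, like torch.cat(..., dim=1)
    return np.concatenate((image, object_mask), axis=1)


batch, height, width = 2, 256, 256
image = np.zeros((batch, 3, height, width), dtype=np.float32)
object_mask = np.ones((batch, 1, height, width), dtype=np.float32)

x = attention_input(image, object_mask)
print(x.shape)  # (2, 4, 256, 256): the new conv1 needs in_channels=4

# The replacement conv1 keeps ResNet's kernel 7, stride 2 and padding 3,
# so it halves the spatial size exactly as the original layer does
print(conv_out_size(height, kernel=7, stride=2, padding=3))  # 128
```

Because kernel, stride, and padding match the original ResNet stem, the rest of the encoder sees feature maps of the usual size. The trade-off, assuming the default ImageNet encoder weights are loaded, is that the new first layer starts from random weights rather than pretrained ones.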

```python
import segmentation_models_pytorch as smp
import torch.nn as nn

# Declare a U-Net with two output classes: two masks (neighboring objects and their shadows)
model = smp.Unet(encoder_name='resnet34', classes=2, activation='sigmoid')
# Replace the first convolutional layer: the input must have four channels
# (3 color channels of the shadow-free image + 1 mask of the inserted object)
model.encoder.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2),
                                padding=(3, 3), bias=False)
```



Declare the loss function, metric and optimizer.

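```python
import segmentation_models_pytorch as smp
import torch

# smp.utils ships with older segmentation_models_pytorch releases and may be
# deprecated or removed in newer ones
loss = smp.utils.losses.DiceLoss()
metric = smp.utils.metrics.IoU(threshold=0.5)
optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=1e-4)])
```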
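For intuition about what the Dice loss and the IoU metric measure, here is a NumPy sketch of both quantities for a single predicted mask. It is a simplified illustration, not the code of segmentation_models.pytorch (the library adds its own smoothing constants and channel handling). It also shows why argument order matters: the 0.5 threshold is applied to the prediction, and in smp.utils the prediction is the first argument.

```python
import numpy as np


def dice_loss(pred, target, eps=1e-7):
    # Soft Dice loss on probabilities: 1 - 2|P*G| / (|P| + |G|)
    intersection = np.sum(pred * target)
    return 1.0 - (2.0 * intersection + eps) / (np.sum(pred) + np.sum(target) + eps)


def iou(pred, target, threshold=0.5, eps=1e-7):
    # Binarize the prediction first, then compute intersection over union
    pred_bin = (pred > threshold).astype(np.float64)
    intersection = np.sum(pred_bin * target)
    union = np.sum(pred_bin) + np.sum(target) - intersection
    return (intersection + eps) / (union + eps)


target = np.array([[1.0, 1.0], [0.0, 0.0]])
pred = np.array([[0.9, 0.6], [0.4, 0.1]])

print(round(dice_loss(pred, target), 4))  # 0.25
print(round(iou(pred, target), 4))        # 1.0
# Swapping the arguments thresholds the ground truth instead of the prediction
print(round(iou(target, pred), 4))        # 0.6
```

The soft Dice loss works directly on probabilities, so it stays differentiable and can drive training, while the thresholded IoU is only reported as a metric.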



Create a function to train the attention block. The training procedure is standard and consists of three loops: an outer loop over epochs and, inside it, a training loop over batches and a validation loop over batches.

On each training iteration, a batch from the dataloader is passed forward through the model to obtain predictions. Then the loss function and the metric are computed, a backward pass (backpropagation of the error) computes the gradients, and the weights are updated. The validation loop only runs the forward pass and computes the loss and the metric, without updating the weights.
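The three nested loops can be seen in isolation in the following framework-free sketch. It trains a hypothetical one-parameter model `y = w * x` with NumPy and a hand-derived gradient, which stands in for the U-Net and autograd; the data and learning rate are made up for illustration. The structure matches the `train` function below: epochs, training batches with forward pass, loss, backward pass, and update, then validation batches without updates, and a checkpoint whenever the validation score improves.

```python
import numpy as np


def run_training(n_epoch, train_batches, valid_batches, lr=0.1):
    w = 0.0  # the single weight of the toy model y = w * x
    best_score, best_w = -np.inf, None
    history = []
    for epoch in range(n_epoch):                 # loop 1: epochs
        train_loss = 0.0
        for x, y in train_batches:               # loop 2: training batches
            pred = w * x                         # forward pass
            loss = np.mean((pred - y) ** 2)      # MSE loss
            grad = np.mean(2 * (pred - y) * x)   # backward pass (analytic gradient)
            w -= lr * grad                       # weight update
            train_loss += loss
        valid_loss = 0.0
        for x, y in valid_batches:               # loop 3: validation, no updates
            valid_loss += np.mean((w * x - y) ** 2)
        train_loss /= len(train_batches)
        valid_loss /= len(valid_batches)
        score = -valid_loss                      # higher is better, like IoU
        if score > best_score:                   # "save" only on improvement
            best_score, best_w = score, w
        history.append((train_loss, valid_loss))
    return best_w, history


x_all = np.linspace(-1.0, 1.0, 8)
y_all = 3.0 * x_all  # ground truth: w = 3
train_batches = [(x_all[:4], y_all[:4]), (x_all[4:], y_all[4:])]
x_val = np.linspace(-0.9, 0.9, 4)
valid_batches = [(x_val, 3.0 * x_val)]

best_w, history = run_training(50, train_batches, valid_batches)
print(best_w)
```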

```python
import torch


def train(n_epoch, train_loader, valid_loader, model_path, model, loss,
          metric, optimizer, device):
    """Train the attention block.

    n_epoch: number of epochs
    train_loader: dataloader for training samples
    valid_loader: dataloader for validation samples
    model_path: path to save the model
    model: the previously declared model
    loss: loss function
    metric: metric
    optimizer: optimizer
    device: the torch.device to train on
    """
    model.to(device)

    max_score = 0
    total_train_steps = len(train_loader)
    total_valid_steps = len(valid_loader)

    # Start the training cycle
    print('Start training!')

    for epoch in range(n_epoch):
        # Put the model into training mode
        model.train()
        train_loss = 0.0
        train_metric = 0.0

        # Batch training cycle
        for data in train_loader:
            noshadow_image = data[0][:, :3].to(device)
            robject_mask = torch.unsqueeze(data[1][:, 0], 1).to(device)
            rshadow_mask = torch.unsqueeze(data[1][:, 1], 1).to(device)
            mask = torch.unsqueeze(data[1][:, 2], 1).to(device)

            # Forward pass through the model
            model_input = torch.cat((noshadow_image, mask), dim=1)
            model_output = model(model_input)

            # Compare the model output with the ground truth
            # (smp.utils losses and metrics take the prediction first)
            ground_truth = torch.cat((robject_mask, rshadow_mask), dim=1)
            loss_result = loss(model_output, ground_truth)
            train_metric += metric(model_output, ground_truth).item()

            # Backward pass and weight update
            optimizer.zero_grad()
            loss_result.backward()
            optimizer.step()

            train_loss += loss_result.item()

        # Put the model into eval mode
        model.eval()
        valid_loss = 0.0
        valid_metric = 0.0

        # Batch validation cycle: no gradients are needed
        with torch.no_grad():
            for data in valid_loader:
                noshadow_image = data[0][:, :3].to(device)
                robject_mask = torch.unsqueeze(data[1][:, 0], 1).to(device)
                rshadow_mask = torch.unsqueeze(data[1][:, 1], 1).to(device)
                mask = torch.unsqueeze(data[1][:, 2], 1).to(device)

                # Forward pass through the model
                model_input = torch.cat((noshadow_image, mask), dim=1)
                model_output = model(model_input)

                # Compare the model output with the ground truth
                ground_truth = torch.cat((robject_mask, rshadow_mask), dim=1)
                valid_loss += loss(model_output, ground_truth).item()
                valid_metric += metric(model_output, ground_truth).item()

        train_loss = train_loss / total_train_steps
        train_metric = train_metric / total_train_steps
        valid_loss = valid_loss / total_valid_steps
        valid_metric = valid_metric / total_valid_steps

        print(f'\nEpoch {epoch}, train_loss: {train_loss}, train_metric: {train_metric}, '
              f'valid_loss: {valid_loss}, valid_metric: {valid_metric}')

        # If the validation metric (IoU) reached a new maximum, save the model
        if max_score < valid_metric:
            max_score = valid_metric
            torch.save(model.state_dict(), model_path)
            print('Model saved!')


# Call the function:

# Number of epochs
n_epoch = 10
# Path to save the model
model_path = '/path/for/model/saving'

train(n_epoch=n_epoch,
      train_loader=train_loader,
      valid_loader=valid_loader,
      model_path=model_path,
      model=model,
      loss=loss,
      metric=metric,
      optimizer=optimizer,
      device=device)
```



After the training of the attention block is completed, proceed to the main part of the network.

Shadow generation block training. For the shadow generation block, we again use U-Net, this time with a lighter encoder: resnet18.

Since the shadow generation block takes a shadow-free image and three masks as input (the mask of the inserted object, the mask of neighboring objects, and the mask of their shadows), we replace the first convolutional layer of the model: the module receives a 6-channel tensor (3 color channels + 3 black-and-white mask channels).

After the U-Net, we add four refinement blocks. Each block is a sequence of BatchNorm2d, ReLU, and Conv2d. The generator returns two outputs: a coarse one computed directly from the U-Net features and a refined one computed after the refinement blocks.
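The order of the six input channels matters: the SG block must receive them in the same order during training and inference. The NumPy sketch below uses hypothetical random arrays in place of the real tensors and builds the input in both ways used in this article. During training the order is (image, inserted-object mask, neighbor mask, neighbor-shadow mask). At inference, the attention block's two output masks are appended to the image-plus-object-mask tensor. The sketch checks that both paths give the same layout.

```python
import numpy as np

rng = np.random.default_rng(0)
b, h, w = 1, 8, 8
noshadow_image = rng.standard_normal((b, 3, h, w))
mask = rng.standard_normal((b, 1, h, w))          # inserted object
robject_mask = rng.standard_normal((b, 1, h, w))  # neighboring objects
rshadow_mask = rng.standard_normal((b, 1, h, w))  # their shadows

# Training order: image, inserted-object mask, neighbor mask, neighbor-shadow mask
train_input = np.concatenate((noshadow_image, mask, robject_mask, rshadow_mask), axis=1)

# Inference order: (image + inserted-object mask), then the attention block's
# two output channels (neighbors, shadows) appended at the end
tensor_sg = np.concatenate((noshadow_image, mask), axis=1)
attention_output = np.concatenate((robject_mask, rshadow_mask), axis=1)
inference_input = np.concatenate((tensor_sg, attention_output), axis=1)

print(train_input.shape)                             # (1, 6, 8, 8)
print(np.array_equal(train_input, inference_input))  # True
```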

Declare a generator class.

```python
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn


class Generator_with_Refin(nn.Module):
    def __init__(self, encoder):
        """Generator initialization."""
        super(Generator_with_Refin, self).__init__()
        # U-Net backbone; its segmentation head is replaced below
        self.generator = smp.Unet(
            encoder_name=encoder,
            classes=1,
            activation='identity',
        )
        # Six input channels: 3 color channels + 3 masks
        self.generator.encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7),
                                                 stride=(2, 2), padding=(3, 3),
                                                 bias=False)
        # Return the 16-channel decoder features instead of a 1-channel mask
        self.generator.segmentation_head = nn.Identity()
        # Coarse output head: 16 feature channels -> 3-channel image
        self.SG_head = nn.Conv2d(in_channels=16, out_channels=3,
                                 kernel_size=3, stride=1, padding=1)

        # Four refinement blocks, each BatchNorm2d -> ReLU -> Conv2d
        self.refinement = torch.nn.Sequential()
        for i in range(4):
            self.refinement.add_module(f'refinement{3*i+1}', nn.BatchNorm2d(16))
            self.refinement.add_module(f'refinement{3*i+2}', nn.ReLU())
            refinement3 = nn.Conv2d(in_channels=16, out_channels=16,
                                    kernel_size=3, stride=1, padding=1)
            self.refinement.add_module(f'refinement{3*i+3}', refinement3)

        # Refined output head: 16 feature channels -> 3-channel image
        self.output1 = nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3,
                                 stride=1, padding=1)

    def forward(self, x):
        """Forward pass: returns the coarse and the refined outputs."""
        x = self.generator(x)
        out1 = self.SG_head(x)

        x = self.refinement(x)
        x = self.output1(x)
        return out1, x
```



Declare a discriminator class.

```python
import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        # Four stride-2 blocks downsample by 2 ** 4; input sides should be multiples of 16
        patch_h, patch_w = in_height // 2 ** 4, in_width // 2 ** 4
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3,
                                    stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3,
                                    stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters,
                                              first_block=(i == 0)))
            in_filters = out_filters

        # Final layer maps the features to a one-channel map of patch scores
        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1,
                                padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
```
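The discriminator outputs a grid of scores, one per image patch, rather than a single number (a PatchGAN-style design). The plain-Python sketch below traces the spatial size through the four blocks using the standard convolution output-size formula, and compares the result with the `in_height // 2 ** 4` shortcut used for `output_shape`. For 256x256 inputs the two agree. For a side that is not a multiple of 16 (250 is a made-up example) they disagree, which would give the adversarial targets the wrong shape. The helper names are hypothetical.

```python
def conv_out_size(size, kernel=3, stride=1, padding=1):
    # Output-size formula for a 2D convolution (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_output_size(size, n_blocks=4):
    for _ in range(n_blocks):
        size = conv_out_size(size, stride=1)  # 3x3, stride 1: size unchanged
        size = conv_out_size(size, stride=2)  # 3x3, stride 2: size roughly halved
    return conv_out_size(size, stride=1)      # final 3x3 conv to one channel


for side in (256, 250):
    actual = discriminator_output_size(side)
    shortcut = side // 2 ** 4
    print(side, actual, shortcut)
# 256 16 16
# 250 16 15
```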



Declare generator and discriminator model objects, as well as loss functions and optimizers for the generator and discriminator.

```python
import torch
import torch.nn as nn
# Assumption: ContentLoss is the perceptual loss from the piq library
from piq import ContentLoss

generator = Generator_with_Refin('resnet18')
discriminator = Discriminator(input_shape=(3, 256, 256))

l2loss = nn.MSELoss()
perloss = ContentLoss(feature_extractor="vgg16", layers=("relu3_3",))
GANloss = nn.MSELoss()

optimizer_G = torch.optim.Adam([dict(params=generator.parameters(), lr=1e-4)])
optimizer_D = torch.optim.Adam([dict(params=discriminator.parameters(), lr=1e-6)])
```
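Using `nn.MSELoss` as the GAN loss, with all-ones (real) and all-zeros (fake) targets, corresponds to the least-squares GAN (LSGAN) objective. The NumPy sketch below writes out the discriminator and generator terms the same way the SG training function combines them. It runs on a hypothetical 2x2 grid of discriminator patch scores.

```python
import numpy as np


def mse(pred, target):
    # Same reduction as nn.MSELoss(): mean of squared differences
    return np.mean((pred - target) ** 2)


# Hypothetical discriminator outputs (patch score maps)
d_real = np.array([[0.9, 0.8], [1.0, 0.7]])  # scores for real shadowed images
d_fake = np.array([[0.2, 0.1], [0.4, 0.3]])  # scores for generated images

valid = np.ones_like(d_real)  # target for "real"
fake = np.zeros_like(d_fake)  # target for "fake"

# Discriminator: push real scores toward 1 and fake scores toward 0
loss_d = (mse(d_real, valid) + mse(d_fake, fake)) / 2

# Generator: push the discriminator's scores for generated images toward 1
loss_g = mse(d_fake, valid)

print(round(loss_d, 4), round(loss_g, 4))  # 0.055 0.575
```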



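Everything is now ready for training. Below is the function for training the SG block. It is called the same way as the attention block training function, but it also takes a learning rate scheduler, a list of loss functions, the paths for saving the models, the loss coefficients, and a TensorBoard writer.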

```python
import numpy as np
import torch
from tqdm import tqdm


def train(generator, discriminator, device, n_epoch, optimizer_G, optimizer_D,
          train_loader, valid_loader, scheduler, losses, models_paths, bettas, writer):
    """Function for training the SG block.

    generator: generator model
    discriminator: discriminator model
    device: torch device for training
    n_epoch: number of epochs
    optimizer_G: optimizer for the generator model
    optimizer_D: optimizer for the discriminator model
    train_loader: dataloader for training samples
    valid_loader: dataloader for validation samples
    scheduler: scheduler that changes the learning rate
    losses: list of loss functions [L2, perceptual, GAN]
    models_paths: list of paths for saving the models [generator, discriminator]
    bettas: list of coefficients for the loss functions
    writer: TensorBoard writer
    """
    # Transfer the models to the training device
    generator.to(device)
    discriminator.to(device)

    # Minimum validation loss seen so far
    val_common_min = np.inf

    print('Start training!')
    for epoch in range(n_epoch):
        # Put the models into training mode
        generator.train()
        discriminator.train()

        # Lists for loss function values
        train_l2_loss, train_per_loss, train_common_loss, train_D_loss = [], [], [], []
        valid_l2_loss, valid_per_loss, valid_common_loss = [], [], []

        print('Cycle by batches:')
        for batch_i, data in enumerate(tqdm(train_loader)):
            noshadow_image = data[2][:, :3].to(device)
            shadow_image = data[2][:, 3:].to(device)
            robject_mask = torch.unsqueeze(data[3][:, 0], 1).to(device)
            rshadow_mask = torch.unsqueeze(data[3][:, 1], 1).to(device)
            mask = torch.unsqueeze(data[3][:, 2], 1).to(device)

            # Prepare the 6-channel input tensor for the model
            model_input = torch.cat((noshadow_image, mask, robject_mask, rshadow_mask), dim=1)

            # ------------ Train the generator ----------------------------------
            shadow_mask_tensor1, shadow_mask_tensor2 = generator(model_input)
            result_nn_tensor1 = torch.add(noshadow_image, shadow_mask_tensor1)
            result_nn_tensor2 = torch.add(noshadow_image, shadow_mask_tensor2)

            # Inputs for the perceptual loss, squashed into (0, 1) with a sigmoid
            for_per_shadow_image_tensor = torch.sigmoid(shadow_image)
            for_per_result_nn_tensor1 = torch.sigmoid(result_nn_tensor1)
            for_per_result_nn_tensor2 = torch.sigmoid(result_nn_tensor2)

            # Adversarial ground truths, created directly on the training device
            target_shape = (data[2].size(0), *discriminator.output_shape)
            valid = torch.ones(target_shape, device=device)
            fake = torch.zeros(target_shape, device=device)

            # Calculate loss functions
            l2_loss = losses[0](shadow_image, result_nn_tensor1) + losses[0](shadow_image, result_nn_tensor2)
            per_loss = (losses[1](for_per_shadow_image_tensor, for_per_result_nn_tensor1)
                        + losses[1](for_per_shadow_image_tensor, for_per_result_nn_tensor2))
            gan_loss = losses[2](discriminator(result_nn_tensor2), valid)
            common_loss = bettas[0] * l2_loss + bettas[1] * per_loss + bettas[2] * gan_loss

            optimizer_G.zero_grad()
            common_loss.backward()
            optimizer_G.step()

            # ------------ Train the discriminator ------------------------------
            optimizer_D.zero_grad()

            loss_real = losses[2](discriminator(shadow_image), valid)
            # detach() stops the discriminator loss from updating the generator
            loss_fake = losses[2](discriminator(result_nn_tensor2.detach()), fake)
            loss_D = (loss_real + loss_fake) / 2

            loss_D.backward()
            optimizer_D.step()

            # ------------------------------------------------------------------
            train_l2_loss.append((bettas[0] * l2_loss).item())
            train_per_loss.append((bettas[1] * per_loss).item())
            train_D_loss.append((bettas[2] * loss_D).item())
            train_common_loss.append(common_loss.item())

        # Put the generator into eval mode
        generator.eval()

        # Validation: no gradients are needed
        with torch.no_grad():
            for batch_i, data in enumerate(valid_loader):
                noshadow_image = data[2][:, :3].to(device)
                shadow_image = data[2][:, 3:].to(device)
                robject_mask = torch.unsqueeze(data[3][:, 0], 1).to(device)
                rshadow_mask = torch.unsqueeze(data[3][:, 1], 1).to(device)
                mask = torch.unsqueeze(data[3][:, 2], 1).to(device)

                # Prepare the input for the model
                model_input = torch.cat((noshadow_image, mask, robject_mask, rshadow_mask), dim=1)
                shadow_mask_tensor1, shadow_mask_tensor2 = generator(model_input)

                result_nn_tensor1 = torch.add(noshadow_image, shadow_mask_tensor1)
                result_nn_tensor2 = torch.add(noshadow_image, shadow_mask_tensor2)

                for_per_shadow_image_tensor = torch.sigmoid(shadow_image)
                for_per_result_nn_tensor1 = torch.sigmoid(result_nn_tensor1)
                for_per_result_nn_tensor2 = torch.sigmoid(result_nn_tensor2)

                # Calculate loss functions (no GAN term during validation)
                l2_loss = losses[0](shadow_image, result_nn_tensor1) + losses[0](shadow_image, result_nn_tensor2)
                per_loss = (losses[1](for_per_shadow_image_tensor, for_per_result_nn_tensor1)
                            + losses[1](for_per_shadow_image_tensor, for_per_result_nn_tensor2))
                common_loss = bettas[0] * l2_loss + bettas[1] * per_loss

                valid_per_loss.append((bettas[1] * per_loss).item())
                valid_l2_loss.append((bettas[0] * l2_loss).item())
                valid_common_loss.append(common_loss.item())

        # Average the values of the loss functions
        tr_l2_loss = np.mean(train_l2_loss)
        val_l2_loss = np.mean(valid_l2_loss)
        tr_per_loss = np.mean(train_per_loss)
        val_per_loss = np.mean(valid_per_loss)
        tr_common_loss = np.mean(train_common_loss)
        val_common_loss = np.mean(valid_common_loss)
        tr_D_loss = np.mean(train_D_loss)

        # Add results to TensorBoard
        writer.add_scalar('tr_l2_loss', tr_l2_loss, epoch)
        writer.add_scalar('val_l2_loss', val_l2_loss, epoch)
        writer.add_scalar('tr_per_loss', tr_per_loss, epoch)
        writer.add_scalar('val_per_loss', val_per_loss, epoch)
        writer.add_scalar('tr_common_loss', tr_common_loss, epoch)
        writer.add_scalar('val_common_loss', val_common_loss, epoch)
        writer.add_scalar('tr_D_loss', tr_D_loss, epoch)

        # Print information
        print(f'\nEpoch {epoch}, tr_common loss: {tr_common_loss:.4f}, '
              f'val_common loss: {val_common_loss:.4f}, D_loss {tr_D_loss:.4f}')

        if val_common_loss <= val_common_min:
            # Save the best models
            torch.save(generator.state_dict(), models_paths[0])
            torch.save(discriminator.state_dict(), models_paths[1])
            val_common_min = val_common_loss
            print('Model saved!')

        # Make a scheduler step; passing the monitored value suits ReduceLROnPlateau
        scheduler.step(val_common_loss)
```



Training Process

Visualization of the learning process

Graphs, general information. For training, I used a GTX 1080 Ti graphics card on a Hostkey server. During training, I tracked the loss functions with TensorBoard and plotted them. The figures below show the training graphs for the training and validation samples.

Training Graphs — Training Samples

The second figure is especially useful because the validation samples are not used to train the generator, so they provide an independent check. The training graphs show that the loss reached a plateau at roughly the 200th to 250th epoch. At that point, it already seemed possible to stop training the generator, since the loss function was no longer decreasing monotonically.

However, it is useful to look at the training graphs on a logarithmic scale, because it shows more clearly whether the curve is still decreasing. The graph of the logarithm of the validation loss shows that stopping at roughly the 200th to 250th epoch would have been too early; training could have been stopped later, at about the 400th epoch.
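One way to make the "is it still decreasing?" judgment less visual is to compare absolute and relative changes, or to fit a line to the log of the validation loss over a recent window of epochs. A clearly negative slope means the loss is still falling. The sketch below runs on a synthetic, made-up loss curve, not on this article's training logs. On such a curve, the window around epochs 200 to 249 shows only a small absolute drop, which looks flat on a linear axis, but a drop of about a third in relative terms, which is what a logarithmic axis makes visible.

```python
import numpy as np


def log_loss_slope(losses, window):
    # Slope of a straight line fitted to log10(loss) over the last `window` epochs
    recent = np.asarray(losses[-window:], dtype=np.float64)
    slope, _intercept = np.polyfit(np.arange(len(recent)), np.log10(recent), 1)
    return slope


# Synthetic validation loss: slow exponential decay toward a small floor
epochs = np.arange(400)
val_loss = 0.05 * np.exp(-0.01 * epochs) + 0.001

window = val_loss[200:250]
abs_drop = window[0] - window[-1]        # change seen on a linear axis
rel_drop = 1.0 - window[-1] / window[0]  # change seen on a log axis

print(f'absolute drop: {abs_drop:.4f}, relative drop: {rel_drop:.0%}')
print(f'log10 slope over epochs 200-249: {log_loss_slope(val_loss[:250], 50):.5f}')
```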

 

Training Graphs — Validation Samples


To make the experiment easier to follow, the predicted image was saved periodically (see the GIF visualizing the learning process above).


Some difficulties. During training, we had to solve a simple problem: incorrect weighting of the loss functions.

Since our final loss function is a weighted sum of the other loss functions, the contribution of each term to the total must be tuned separately through its coefficient. The best option is to take the coefficients suggested in the original article.

If the loss functions are balanced incorrectly, the results can be unsatisfactory. For example, if the contribution of L2 is set too high, training can even come to a standstill. L2 converges quickly, but removing it from the total entirely is also undesirable: the output shadow becomes less realistic and less consistent in color and transparency.
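To check the balance in practice, it helps to look at each weighted term's share of the total. The SG training function already makes this possible by logging `bettas[i] * loss` separately. The sketch below computes those shares. The coefficient and loss values are placeholders chosen for illustration, not the values from the original article or from this training run, and `loss_shares` is a hypothetical helper.

```python
def loss_shares(betas, loss_values):
    # Weighted terms beta_i * L_i and each term's fraction of the total
    weighted = [beta * value for beta, value in zip(betas, loss_values)]
    total = sum(weighted)
    return total, [w / total for w in weighted]


# Placeholder numbers for (L2, perceptual, GAN); not tuned coefficients
betas = [1.0, 0.5, 0.1]
loss_values = [0.4, 0.3, 0.25]

total, shares = loss_shares(betas, loss_values)
for name, share in zip(('L2', 'perceptual', 'GAN'), shares):
    print(f'{name}: {share:.1%}')
```

If one share stays close to 100% for many epochs, that term is effectively the whole objective. That is a signal to revisit its coefficient.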

An example of a generated shadow in the absence of an L2-loss contribution



The picture shows the ground truth image on the left and the generated image on the right.

Inference. For prediction and testing, we will combine the attention and SG models into one ARShadowGAN class. 

```python
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn


class ARShadowGAN(nn.Module):
    def __init__(self, model_path_attention, model_path_SG, encoder_att='resnet34',
                 encoder_SG='resnet18', device='cuda:0'):
        super(ARShadowGAN, self).__init__()

        self.device = torch.device(device)

        # Attention block: U-Net with a 4-channel input and two output masks
        self.model_att = smp.Unet(
            classes=2,
            encoder_name=encoder_att,
            activation='sigmoid'
        )
        self.model_att.encoder.conv1 = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2),
                                                 padding=(3, 3), bias=False)
        self.model_att.load_state_dict(torch.load(model_path_attention, map_location=self.device))
        self.model_att.to(self.device)

        # Shadow generation block
        self.model_SG = Generator_with_Refin(encoder_SG)
        self.model_SG.load_state_dict(torch.load(model_path_SG, map_location=self.device))
        self.model_SG.to(self.device)

    def forward(self, tensor_att, tensor_SG):
        self.model_att.eval()
        with torch.no_grad():
            robject_rshadow_tensor = self.model_att(tensor_att)

        # Binarize the predicted masks at 0.5 and map them to {-1, 1}
        # (works on both CPU and GPU, unlike torch.cuda.FloatTensor)
        robject_rshadow_tensor = (robject_rshadow_tensor >= 0.5).float() * 2.0 - 1.0

        # Append the neighboring-object and shadow masks: 6 input channels in total
        tensor_SG = torch.cat((tensor_SG, robject_rshadow_tensor), dim=1)

        self.model_SG.eval()
        with torch.no_grad():
            output_mask1, output_mask2 = self.model_SG(tensor_SG)

        # Add the refined shadow output to the shadow-free image
        result = torch.add(tensor_SG[:, :3, ...], output_mask2)

        return result, output_mask2
```
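Before the SG block runs, the forward pass binarizes the attention masks at 0.5 and rescales them to -1 and 1. This matches the [-1, 1] range implied by the `127.5 * (x + 1.0)` conversion in the inference code below. The NumPy check below uses made-up probabilities to show that the step-by-step version (threshold to 0/1, then `2 * (x - 0.5)`) and a one-line comparison-based version produce identical values.

```python
import numpy as np

probs = np.array([0.0, 0.2, 0.49, 0.5, 0.73, 1.0])

# Step-by-step version: threshold to {0, 1}, then rescale to {-1, 1}
stepwise = probs.copy()
stepwise[stepwise >= 0.5] = 1
stepwise[stepwise < 0.5] = 0
stepwise = 2 * (stepwise - 0.5)

# One-line version: compare, cast, rescale
one_line = (probs >= 0.5).astype(np.float64) * 2.0 - 1.0

print(stepwise.tolist())                   # [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]
print(np.array_equal(stepwise, one_line))  # True
```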



The inference code is below.

```python
import os.path as osp

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

# Specify the paths to data and checkpoints
dataset_path = '/content/arshadowgan/uploaded'
result_path = '/content/arshadowgan/uploaded/shadow'

path_att = '/content/drive/MyDrive/ARShadowGAN-like/attention.pth'
path_SG = '/content/drive/MyDrive/ARShadowGAN-like/SG_generator.pth'

# Declare dataset and dataloader
dataset = ARDataset(dataset_path, augmentation=get_validation_augmentation(256),
                    preprocessing=get_preprocessing(), is_train=False)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

# Provide the device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Declare the complete model
model = ARShadowGAN(
    encoder_att='resnet34',
    encoder_SG='resnet18',
    model_path_attention=path_att,
    model_path_SG=path_SG,
    device=device
)
# Put it into testing mode
model.eval()

# Prediction
for i, data in enumerate(dataloader):
    tensor_att = torch.cat((data[0][:, :3], torch.unsqueeze(data[1][:, -1], dim=1)), dim=1).to(device)
    tensor_SG = torch.cat((data[2][:, :3], torch.unsqueeze(data[3][:, -1], dim=1)), dim=1).to(device)

    with torch.no_grad():
        result, shadow_mask = model(tensor_att, tensor_SG)

    # Map values from [-1, 1] back to [0, 255]; clip before the uint8 cast
    shadow_mask = np.uint8(np.clip(127.5 * (shadow_mask[0].cpu().numpy().transpose((1, 2, 0)) + 1.0), 0, 255))
    output_image = np.uint8(np.clip(127.5 * (result[0].cpu().numpy().transpose((1, 2, 0)) + 1.0), 0, 255))

    # With more than one input image, use a per-image name such as f'test_{i}.png'
    # to avoid overwriting; cv2.imwrite expects BGR channel order
    cv2.imwrite(osp.join(result_path, 'test.png'), output_image)
    print('result saved: ' + result_path + '/test.png')
```
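The network works with values in [-1, 1], so saving an image means mapping back to [0, 255] with `127.5 * (x + 1.0)`. The NumPy sketch below uses made-up pixel values to show two things. First, the parentheses matter: `127.5 * x + 1.0` is a different mapping. Second, clipping before the `uint8` cast is worthwhile, because generated values can fall outside [-1, 1] and NumPy's conversion of out-of-range floats to `uint8` is not guaranteed to saturate. `to_uint8` is a hypothetical helper.

```python
import numpy as np


def to_uint8(x):
    # Map [-1, 1] to [0, 255], clip out-of-range values, then cast
    return np.clip(127.5 * (x + 1.0), 0, 255).astype(np.uint8)


pixels = np.array([-1.0, 0.0, 1.0, 1.25])

print(to_uint8(pixels).tolist())        # [0, 127, 255, 255]
print((127.5 * pixels + 1.0).tolist())  # wrong mapping: [-126.5, 1.0, 128.5, 160.375]
```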



Conclusion

This article discussed a generative adversarial network using the example of one of the ambitious and difficult tasks at the junction of augmented reality and computer vision. Overall, the resulting model can generate shadows, although not always perfectly.

Note that a GAN is not the only way to generate shadows. Other approaches use, for example, 3D object reconstruction techniques, differentiable rendering, etc.

All of the code above is in the repository, and examples of running it are in the Google Colab notebook.

P.S. I would be happy to answer any questions you may have and to receive your feedback. Thank you for your attention!
