Getting Predictions After Each Batch in Keras and TF2.2

In this article, see how to get predictions (and possibly more) the easy way in Keras training code.

By Ran Levy · Sep. 18, 20 · Tutorial


In this tutorial, I will demonstrate how I managed to get the predictions after each training batch from a Keras model.

With native TensorFlow training code this is pretty easy, since we implement the training loop ourselves, in which we call:

Python

_, batch_pred = sess.run([optimizer, pred], feed_dict={x: batch_x, y: batch_y})

However, in Keras we call the fit function, which does all the training for us. That is pretty convenient when starting from scratch, but when a project needs to go deeper, this convenience turns into significant discomfort.

In my project, I wanted to change the data generator during training, based on the network's predictions. Although this is not common, it is pretty basic functionality to expect from Keras:

Get the predictions, or even the activations of the internal layers, after each batch.

Digging through the internet did not turn up a real solution. There are different workarounds for different versions of TF (TF1 [1], TF2 [2], and TF2.2 [3] - the last is only partial). Other solutions suggest writing a custom Keras training loop, thereby giving up the internal Keras logic that the default loop provides.

All the solutions stated above have two major drawbacks:

  1. They do not survive TensorFlow/Keras internal changes and version divergence.
  2. They are too complicated.

I would not have implemented a Keras model from scratch myself, but the open-source project I'm relying on did, and I did not want to make major changes to it.

As usual, I was looking for a practical, easy way to get the predictions that would survive version upgrades. I started digging into the Keras code, looking for a way to get the predictions after each batch.

The major challenge lies in the fact that Keras does not call train_step directly on every batch, but compiles its own implementation of the supplied function. As a result, even print statements are not shown on the screen during training with the fit function.
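
As an aside, this is standard tf.function behavior: a Python print runs only while the function is being traced, whereas tf.print runs on every call of the compiled graph. A minimal, self-contained illustration (not from the original project):

Python

import tensorflow as tf

@tf.function
def step(x):
    print("tracing...")           # executes once, at tracing time only
    tf.print("step input:", x)    # executes on every call
    return x + 1

step(tf.constant(1))  # prints both lines
step(tf.constant(2))  # prints only the tf.print line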

I found some potential places to change the Keras code itself that would get me the predictions. However, that would not work on any environment besides mine.

I kept searching for a solution until I found the "hack":

Return the predictions via the metrics parameter.

It may not be what the Keras authors originally intended, but it works.

The first step is to create a simple custom model, MyModel, that inherits from Model and overrides the train_step function:

Python

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.python.keras.engine.training import _minimize
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.eager import backprop


Python

class MyModel(Model):

    # base code is taken from 'tensorflow.python.keras.engine.training'
    @tf.function
    def train_step(self, data):
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

        with backprop.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(
                y, y_pred, sample_weight, regularization_losses=self.losses)
        # For custom training steps, users can just write:
        #   trainable_variables = self.trainable_variables
        #   gradients = tape.gradient(loss, trainable_variables)
        #   self.optimizer.apply_gradients(zip(gradients, trainable_variables))
        # The _minimize call does a few extra steps unnecessary in most cases,
        # such as loss scaling and gradient clipping.
        _minimize(self.distribute_strategy, tape, self.optimizer, loss,
                  self.trainable_variables)

        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        metrics_ = {m.name: m.result() for m in self.metrics}
        metrics_['pred'] = y_pred    # the "hack": smuggle the predictions out as a metric
        metrics_['inputs'] = x       # the inputs can be passed along the same way
        return metrics_
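
For completeness, here is a minimal sketch of how such a model might be instantiated and compiled; the architecture, layer names, and sizes below are hypothetical placeholders, not taken from the original project:

Python

# Hypothetical toy architecture - swap in your own layers
inputs = tf.keras.Input(shape=(32,))
hidden = tf.keras.layers.Dense(64, activation='relu', name='hidden')(inputs)
outputs = tf.keras.layers.Dense(10, activation='softmax')(hidden)

# Instantiate MyModel instead of Model; compile and fit work as usual
model = MyModel(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])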



The second step is to implement a Keras callback:

Python

from tensorflow.keras.callbacks import Callback

class PredictionHistory(Callback):
    def __init__(self):
        self.predictions = []

    def on_train_batch_end(self, batch, logs={}):
        # the predictions are in the logs parameter: logs['pred']
        # for example, accumulate them for later inspection:
        self.predictions.append(logs['pred'])
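
Since my original goal was to adapt the data generator based on the predictions, this hook is the natural place to do it. A rough sketch, where the GeneratorUpdater name and the generator's update() method are hypothetical:

Python

class GeneratorUpdater(Callback):
    def __init__(self, generator):
        # hypothetical generator object exposing an update() method
        self.generator = generator

    def on_train_batch_end(self, batch, logs={}):
        # feed the fresh batch predictions back into the data generator
        self.generator.update(logs['pred'])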



Now all you need to do is call the Keras fit function with the callback:

Python

callbacks=[PredictionHistory()]  # combined with other callbacks such as checkpoint, tensorboard
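
Put together, a complete call might look like this; x_train and y_train are placeholders for your own data:

Python

history = model.fit(x_train, y_train,
                    batch_size=32,
                    epochs=5,
                    callbacks=[PredictionHistory()])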


 

And voilà! Each batch now ends with the predictions in hand.

Note that you can add any internal layer you want to the model's outputs, and thus get any internal activation you want to investigate.
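
For example, reusing the hypothetical architecture sketched above, the hidden layer can be exposed as a second output that carries no loss; depending on your TF version you may need to pass losses and targets as dicts keyed by output name:

Python

# Expose the hidden layer's activations as an extra output (hypothetical layers)
model = MyModel(inputs=inputs, outputs=[outputs, hidden])
model.compile(optimizer='adam',
              loss=['sparse_categorical_crossentropy', None])  # no loss on the activations
# logs['pred'] in the callback now holds [main_output, hidden_activations]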

Next time, use native TensorFlow or PyTorch, OK?
