Getting Predictions After Each Batch in Keras and TF2.2
In this article, see how to get predictions (and possibly more) the easy way in Keras training code.
In this tutorial, I will demonstrate how I managed to get the predictions after each training batch in a Keras model.
With plain TensorFlow training code this is easy, since we write the training loop ourselves, and in it we call the model on each batch directly, so the predictions are right there.
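For comparison, here is a minimal sketch of such a hand-written TensorFlow 2 loop. The toy model, the random data and names such as all_predictions are illustrative assumptions, not code from the original project; the point is only that the per-batch predictions are an ordinary tensor inside the loop.

```python
import numpy as np
import tensorflow as tf

# Toy regression setup (illustrative only): 32 samples, 4 features, batches of 8.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
loss_fn = tf.keras.losses.MeanSquaredError()

all_predictions = []  # one array of predictions per batch
for x_batch, y_batch in dataset:
    with tf.GradientTape() as tape:
        predictions = model(x_batch, training=True)  # forward pass for this batch
        loss = loss_fn(y_batch, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Because we own the loop, the batch predictions are directly available here,
    # e.g. to decide how the data pipeline should change.
    all_predictions.append(predictions.numpy())

print(len(all_predictions), all_predictions[0].shape)  # 4 (8, 1)
```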
However, in Keras we call the fit function, which does all the training for us. This is convenient when training from scratch, but once you need to go deeper in a project, that convenience turns into significant discomfort.
In my project, I wanted to change the data generator during training, based on the network's predictions. Although this is not a common need, it is fairly basic functionality to expect from Keras:
Get the predictions, or even the activations of the internal layers, after each batch.
Searching the internet did not turn up a solution. There are different solutions for different versions of TF (TF1, TF2 and TF2.2, the last being only partial). Other solutions suggested writing a custom Keras training loop instead of the default one, which relies on internal Keras logic.
All of the solutions above have two major drawbacks:
- They do not survive TensorFlow/Keras internal changes and divergence between versions.
- They are too complicated.
I would not have implemented a Keras model from scratch myself, but the open-source project I'm relying on did, and I did not want to make major changes to it.
As usual, I was looking for a practical and easy way to get the predictions, one that would survive version upgrades. I started digging into the Keras code, looking for a way to get the predictions after each batch.
The major challenge lies in the fact that Keras does not simply run train_step as plain Python code on every batch; it wraps the supplied function in its own tf.function graph. As a result, even print statements do not show up on the screen during training.
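A small experiment illustrates this tracing behaviour. It is a sketch assuming a toy model and random data, and the class name TracingModel and the python_calls counter are hypothetical: a plain Python side effect inside an overridden train_step runs only while fit traces the function, not once per batch.

```python
import numpy as np
import tensorflow as tf

python_calls = 0  # incremented by plain Python code, i.e. only while tracing


class TracingModel(tf.keras.Model):  # hypothetical name, for illustration
    def train_step(self, data):
        global python_calls
        python_calls += 1
        print("train_step traced")  # a plain print: shown at trace time only
        return super().train_step(data)  # default Keras training logic


inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = TracingModel(inputs=inputs, outputs=outputs)
model.compile(optimizer="sgd", loss="mse")

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(x, y, batch_size=8, epochs=1, verbose=0)  # 32 / 8 = 4 batches

print("batches: 4, Python-level train_step executions:", python_calls)
```

Passing run_eagerly=True to compile makes Keras execute train_step eagerly, so such prints appear on every batch, at the cost of speed.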
I found some places in the Keras code that I could change to get the predictions. However, such changes would only work in my environment, not in others.
I kept searching for a solution, until I found the "hack":
Return the predictions via the metrics parameter.
It may not be what the Keras authors originally intended, but it works.
The first step is to create a simple custom model, MyModel, that inherits from Model and overrides the train_step function so that it also returns the predictions.
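A minimal sketch of such a MyModel, assuming a single-output regression model trained with a hand-written mean squared error. The loss, the toy data and the key name "predictions" are assumptions for illustration, not the original project's code. Keras passes whatever dictionary train_step returns to the callbacks as the batch logs, so the predictions travel along with the loss.

```python
import numpy as np
import tensorflow as tf


class MyModel(tf.keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = tf.reduce_mean(tf.square(y - y_pred))  # mean squared error
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # The "hack": return the batch predictions next to the loss, so Keras
        # treats them like a metric and hands them to callbacks in `logs`.
        return {"loss": loss, "predictions": y_pred}


inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = MyModel(inputs=inputs, outputs=outputs)
model.compile(optimizer="sgd")  # no compiled loss: train_step computes its own

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
history = model.fit(x, y, batch_size=8, epochs=2, verbose=0)  # no progress bar

# History keeps only the logs of the last batch of each epoch; per-batch
# access needs a callback implementing on_train_batch_end.
last_batch_preds = np.asarray(history.history["predictions"][-1])
print(last_batch_preds.shape)
```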
The second step is to implement a Keras callback that reads the predictions at the end of each batch.
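One possible shape for that callback, as a sketch: it assumes the model's train_step returns the batch predictions under the key "predictions" (a convention of this example, not something Keras defines), stores them, and calls an optional user-supplied hook. The class name PredictionsCallback and the on_predictions hook are hypothetical; the hook is where one could, for example, adjust a data generator.

```python
import numpy as np
import tensorflow as tf


class PredictionsCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback collecting the predictions of every training batch.

    Assumes the model's train_step returns them in its logs under "predictions".
    """

    def __init__(self, on_predictions=None):
        super().__init__()
        self.on_predictions = on_predictions  # optional hook: f(batch_index, preds)
        self.batch_predictions = []

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        if "predictions" not in logs:
            return  # the model does not expose predictions; nothing to do
        # Depending on the TF/Keras version, logs may hold tensors or NumPy
        # arrays; np.asarray handles both.
        preds = np.asarray(logs["predictions"])
        self.batch_predictions.append(preds)
        if self.on_predictions is not None:
            self.on_predictions(batch, preds)


# Usage with hand-made logs, standing in for what fit would pass:
seen = []
cb = PredictionsCallback(on_predictions=lambda b, p: seen.append((b, p.shape)))
cb.on_train_batch_end(0, {"loss": 0.25, "predictions": np.zeros((8, 1), dtype="float32")})
cb.on_train_batch_end(1, {"loss": 0.20})  # no "predictions" key: ignored
print(seen)  # [(0, (8, 1))]
```

During real training it would be passed as model.fit(..., callbacks=[PredictionsCallback()]).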
Now all you need is to call the Keras fit function with the callbacks parameter set to a list containing that callback.
And voilà! Each batch now ends with the predictions.
Note that you can add any internal layer you want to the model's outputs, and thus get any internal activation you want to investigate.
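As a sketch of that idea, assuming a toy functional model with a hidden layer named "hidden" (both hypothetical), the internal activation can be exposed as an extra model output alongside the prediction:

```python
import numpy as np
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(3, activation="relu", name="hidden")(inputs)
prediction = tf.keras.layers.Dense(1, name="prediction")(hidden)

# Expose the hidden activations as a second output; these are the same layers
# (and weights) that produce the prediction.
model = tf.keras.Model(inputs=inputs, outputs=[prediction, hidden])

x = np.random.rand(5, 4).astype("float32")
y_pred, hidden_act = model(x)
print(tuple(y_pred.shape), tuple(hidden_act.shape))
```

Inside an overridden train_step, one would unpack both outputs, compute the loss on the prediction only, and add the activation to the returned dictionary (under a key of your choosing, such as "hidden"), so it reaches the callbacks together with the predictions.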
Next time, use native TensorFlow or PyTorch, OK?
Opinions expressed by DZone contributors are their own.