Getting Predictions After Each Batch in Keras and TF2.2
In this article, see how to get predictions (and possibly more) the easy way in Keras training code.
In this tutorial, I will demonstrate how I managed to get the predictions after each training batch in a Keras model.
With native TensorFlow (TF1-style) training code it is pretty easy, since we implement the training loop ourselves and simply call:
_, batch_pred = sess.run([optimizer, pred], feed_dict={x: batch_x, y: batch_y})
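For context, such a loop looks roughly like the sketch below; x, y, pred, optimizer, sess, and the batch iterator are assumed to be defined elsewhere, and the names are illustrative only.
# TF1-style training loop (sketch; in TF2 this requires tf.compat.v1)
for epoch in range(num_epochs):
    for batch_x, batch_y in batches:
        # the train op and the predictions are fetched in the same run,
        # so the batch predictions are available immediately
        _, batch_pred = sess.run([optimizer, pred],
                                 feed_dict={x: batch_x, y: batch_y})
        # e.g. adapt the data generator here based on batch_pred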
However, in Keras, we call the fit function, which does all the training for us. That is pretty convenient when starting training from scratch, but when you need to go deeper in some projects, this convenience turns into significant discomfort.
In my project, I wanted to change the data generator during training, based on the network's predictions. Although this is not a common requirement, it is pretty basic functionality to expect from Keras:
Get the predictions, or even the activations of internal layers, after each batch.
Digging through the internet did not turn up a solution. There are different approaches for different versions of TF (TF1 [1], TF2 [2] and TF2.2 [3] - the last is only partial). Other solutions suggested writing a custom Keras training loop instead of the default one, which uses internal Keras logic.
All the solutions stated above have two major drawbacks:
- They do not survive TensorFlow/Keras internal changes and version divergence.
- They are too complicated.
I would not have implemented a Keras model from scratch myself, but the open-source project I rely on did, and I did not want to make major changes to that project.
As usual, I was looking for a practical and easy way to get the predictions that would survive version upgrades. I started digging into the Keras code, looking for a way to get the predictions after each batch.
The major challenge lies in the fact that Keras does not call train_step directly on every batch, but compiles its own implementation of the supplied function. As a result, even print statements are not shown on the screen during training with the fit function.
I found some potential places to change Keras' code that would help me get the predictions. However, that would not work in any environment besides my own.
I kept searching for a solution until I found the "hack":
Return the predictions via the metrics parameter.
It may not be what the Keras authors originally intended, but it works.
The first step is to create a simple custom MyModel class that inherits from Model and overrides the train_step function:
from tensorflow.keras.models import Model
from tensorflow.python.keras.engine.training import _minimize
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.eager import backprop
class MyModel(Model):
    # base code is taken from the train_step function in
    # 'tensorflow.python.keras.engine.training'
    def train_step(self, data):
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

        with backprop.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(
                y, y_pred, sample_weight, regularization_losses=self.losses)
        # For custom training steps, users can just write:
        #   trainable_variables = self.trainable_variables
        #   gradients = tape.gradient(loss, trainable_variables)
        #   self.optimizer.apply_gradients(zip(gradients, trainable_variables))
        # The _minimize call does a few extra steps unnecessary in most cases,
        # such as loss scaling and gradient clipping.
        _minimize(self.distribute_strategy, tape, self.optimizer, loss,
                  self.trainable_variables)

        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        metrics_ = {m.name: m.result() for m in self.metrics}
        # the "hack": return the predictions and inputs alongside the metrics
        metrics_['pred'] = y_pred
        metrics_['inputs'] = x
        return metrics_
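For completeness, here is a minimal sketch of how such a model might be built and compiled; the architecture, layer sizes, and names are illustrative only, not taken from the original project.
from tensorflow.keras import layers

# build the network with the functional API, but instantiate MyModel
# instead of the stock Model so that our train_step override is used
inputs = layers.Input(shape=(28, 28, 1))
hidden = layers.Flatten()(inputs)
hidden = layers.Dense(128, activation='relu')(hidden)
outputs = layers.Dense(10, activation='softmax')(hidden)

model = MyModel(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])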
The second step is to implement a Keras callback:
from tensorflow.keras.callbacks import Callback

class PredictionHistory(Callback):
    def __init__(self):
        super().__init__()

    def on_train_batch_end(self, batch, logs=None):
        # the predictions are in the logs parameter: logs['pred']
        pass
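If you actually want to hold on to the per-batch predictions, a slightly fleshed-out sketch could look like this (the batch_predictions attribute is my own naming, not from the original project):
from tensorflow.keras.callbacks import Callback

class PredictionHistory(Callback):
    def __init__(self):
        super().__init__()
        self.batch_predictions = []  # predictions collected batch by batch

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        # 'pred' and 'inputs' are the extra entries returned by MyModel.train_step
        pred = logs.get('pred')
        if pred is not None:
            # depending on the TF version this is an EagerTensor or a numpy array
            self.batch_predictions.append(
                pred.numpy() if hasattr(pred, 'numpy') else pred)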
Now all you need to do is call Keras' fit function with the parameter:
callbacks=[PredictionHistory()]  # combined with other callbacks such as checkpoint, tensorboard
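Putting it all together, a hedged end-to-end sketch of the fit call (train_data stands in for whatever dataset or generator you feed the model):
history = model.fit(
    train_data,                       # tf.data.Dataset, generator, or numpy arrays
    epochs=10,
    callbacks=[PredictionHistory()],  # combined with checkpoints, TensorBoard, etc.
)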
And voilà! Each batch ends with the predictions.
Note that you can add any internal layer you want to the model's outputs, and thus get any internal activation you want to investigate.
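For example, here is a sketch of exposing an internal activation as an extra output; the layer names 'embedding' and 'digits' are hypothetical, and the idea is that the embedding tensor then shows up in logs['pred'] after every batch as well.
from tensorflow.keras import layers

inputs = layers.Input(shape=(28, 28, 1))
flat = layers.Flatten()(inputs)
embedding = layers.Dense(128, activation='relu', name='embedding')(flat)
digits = layers.Dense(10, activation='softmax', name='digits')(embedding)

model = MyModel(inputs=inputs, outputs=[digits, embedding])
# only the main output gets a loss; the 'embedding' output exists purely so
# its activations end up in y_pred (and therefore in logs['pred'])
model.compile(optimizer='adam',
              loss={'digits': 'sparse_categorical_crossentropy'},
              metrics={'digits': 'accuracy'})
# when fitting, pass the labels keyed by output name, e.g. y={'digits': y_train}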
Next time, use native TensorFlow or PyTorch, OK?