Training a Handwritten Digits Classifier in Pytorch With Apache Cassandra Database
How to use data stored in a large-scale database as our training data and how to use that same database as a basic model registry.
Handwritten digit recognition is one of the classic tasks undertaken by students when learning the basics of Neural Networks and Computer Vision. The basic idea is to take a number of labeled images of handwritten digits and use those to train a neural network that is able to classify new unlabeled images. For this demo, we show how to use data stored in a large-scale database as our training data. We also explain how to use that same database as a basic model registry, which can enable model serving as well as potential future retraining.
Introduction
MNIST is a set of datasets that share a particular format useful for educating students about neural networks while presenting them with diverse problems. The MNIST datasets for this demo are a collection of 28 by 28-pixel grayscale images as data and classifications 0-9 as potential labels. This demo works with the original MNIST handwritten digits dataset as well as the MNIST fashion dataset.
Using both of these datasets helps calibrate models, testing whether they are affected by the domain of the classification. If a neural net is good at classifying digits but bad at classifying clothing and accessories, even though in this case the datasets share the same structure, that is evidence that something about the training or structure of the network encodes knowledge specific to digits or handwriting, or is better suited to simple shapes than complex ones.
Pytorch is a Python library that contains data, types, and methods for working with neural networks. We also make use of torchvision, a related library specifically meant for computer vision tasks. Pytorch works with data typed as Tensors and can define different types of layers that can be combined for deep learning, gaining advantages that networks built from a single layer type cannot. Pytorch provides utilities that help us define, train, test, and predict using our models.
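As a quick illustration of the data format, torchvision can download both datasets directly. This is a minimal sketch, not part of the demo repo (which loads the data from CSV files instead):

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),                      # 28x28 grayscale image -> 1x28x28 tensor in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)), # commonly used MNIST mean/std
])

# datasets.FashionMNIST is a drop-in replacement with the same shape and labels
train_set = datasets.MNIST("./data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)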
Astra is a managed database based on Apache Cassandra. It retains Cassandra's horizontal scalability and distributed nature and comes with a number of other advantages. Today, you can use Astra as the primary data store and as a primitive model registry.
Set-Up
For this demo, we use a database called cassio_db and a keyspace named mnist_digits. To create that database, select the “Databases” tab (shown on the left menu), then click on the “Create Database” button (shown on the right) and fill out the needed information for your database.
Once you’ve created your database, you’ll need to generate a Token and the Secure Connect Bundle from the Connect tab in order to connect to your database. Choose the permissions that make the most sense for your use case. For this demo, there’s nothing wrong with choosing Database Administrator, but you can also go as simple as a Read/Write Service account to get the functionality you need.
Never share your token or Secure Connect Bundle with anyone. Together they hold several pieces of data about your database and can be used to access it.
Reminder: for this demo, the assumed name of the keyspace is mnist_digits.
Establishing the Schema
Once the Keyspace has been created, we need to create the Tables that we will be using. Those tables are raw_train, raw_test, and models. We also create a raw_predict table holding data with no labels attached. Open the CQL Console for the database and enter these lines to create your tables.
CREATE TABLE mnist_digits.raw_train (id int PRIMARY KEY, label int, pixels list<int>);
CREATE TABLE mnist_digits.raw_test (id int PRIMARY KEY, label int, pixels list<int>);
CREATE TABLE mnist_digits.models (id uuid PRIMARY KEY, network blob, optimizer blob, upload_date timestamp, comments text);
CREATE TABLE mnist_digits.raw_predict (id int PRIMARY KEY, label int, pixels list<int>);
Next, we need to create the resources we need to connect to the Astra database. Hit the Connect button on the UI and download the Secure Connect Bundle (SCB). Then, hit the “Create a Token” button to create a Database Administrator token and download the text that it returns.
Load the SCB into the environment and put its path between the single quotes on the first line of the auth.py file. Put the generated ID (Client_ID) for the Database Admin token on the second line, and the generated secret (Client_Secret) for the token on the third line.
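With those three values in place, connecting from Python looks roughly like this. It is a minimal sketch using the DataStax Cassandra driver; the names imported from auth.py are assumptions, so match them to whatever your auth.py actually defines.

# Minimal connection sketch, assuming auth.py exposes these three names.
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from auth import SECURE_CONNECT_BUNDLE_PATH, CLIENT_ID, CLIENT_SECRET

cloud_config = {"secure_connect_bundle": SECURE_CONNECT_BUNDLE_PATH}
auth_provider = PlainTextAuthProvider(CLIENT_ID, CLIENT_SECRET)
cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider)
session = cluster.connect("mnist_digits")  # our keyspace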
Installing and Configuring the Raw Data
Run the data loader, load_raw_data.py, using the line below. Modify the train_split variable in load_raw_data.py if you want something other than an 80/20 train/test split.
python3 load_raw_data.py
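Under the hood, the loader writes one row per image. A hypothetical sketch of the core insert (the real script also handles the train/test split, and the CSV file name here is an assumption):

import csv

# Hypothetical sketch; `session` is the Astra connection shown earlier, and
# each CSV row is a label followed by 784 pixel values.
insert = session.prepare(
    "INSERT INTO mnist_digits.raw_train (id, label, pixels) VALUES (?, ?, ?)"
)
with open("mnist_train.csv") as f:
    for i, row in enumerate(csv.reader(f)):
        label, pixels = int(row[0]), [int(p) for p in row[1:]]
        session.execute(insert, (i, label, pixels))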
This may take an hour or more to complete because there are 784 pixel columns (28 × 28) for each data sample (Note: this is why we asked you to select the high-memory option for your GitPod). Once it is complete, make sure that the data was created by running these commands in the CQL Console of the Astra UI.
SELECT id, label from mnist_digits.raw_train limit 5;
SELECT id, label from mnist_digits.raw_test limit 5;
Training
With the data loaded, you should be able to step through the model_training_full_sequence.ipynb notebook without issue, following the comments to train and store models.
Once training is complete, we can retrieve the best stored model from the database by querying for the model with the minimum loss or maximum accuracy; you can view an example of this query in the attached notebook. If we then load the stored model parameters into our network and run test(), we should get similar results for the loss and accuracy of the model.
Administration
Adding New DataSets
This example uses the MNIST handwritten digits dataset. This dataset consists of a set of 28 by 28-pixel grayscale images depicting digits from 0-9, meant to be classified into those ten categories. This repo is easily modified to work with other datasets in this format. The most compatible will be other MNIST datasets, which promise to have the same 28 by 28 image size, the same grayscale pixel values, and the same ten categories. In fact, the Fashion MNIST dataset can be substituted almost exactly for the train and test CSV files included in the repo. If you switch the filenames in load_raw_data.py (sketched below), the rest of the repo can be used as normal.
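For example, if the loader keeps its input paths in variables (an assumption about load_raw_data.py; the actual file names are illustrative), the swap could be as small as:

# Hypothetical: point the loader at the Fashion MNIST CSVs instead.
train_file = "fashion-mnist_train.csv"  # was "mnist_train.csv"
test_file = "fashion-mnist_test.csv"    # was "mnist_test.csv"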
Using the Updated Model on New Data
In order to use this model on new data, the first step is to pull the row for the particular model you want out of Astra. Then, use pkl.load on the network state blob that was saved to turn it back into a state dictionary. Next, we create the Net class and network object the same way we did in the notebook, call network.load_state_dict, and pass it the state dictionary we just loaded. Now we have a network object with the same weights as the one we stored in Astra. We can then load new data, whether from the test loader we create in the notebook, from the data we placed in the raw_predict table, or from somewhere else. Once we have the data and the model, we can call network(data) to run the new data through the model and read the predictions off the results.
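Putting those steps together, here is a sketch of the whole round trip. It assumes the session from earlier, the repo's Net class (sketched in the Customization section below), and uses a placeholder batch where real images would go:

import pickle as pkl
import torch

# Fetch a stored model row; in practice you would filter for the model you want.
row = session.execute("SELECT network FROM mnist_digits.models LIMIT 1").one()

state_dict = pkl.loads(row.network)  # blob (bytes) -> state dictionary
network = Net()                      # same architecture as in the notebook
network.load_state_dict(state_dict)
network.eval()

# `data` stands in for any batch of new images, e.g. rows from raw_predict
data = torch.zeros(1, 1, 28, 28)     # placeholder: one blank 28x28 image

with torch.no_grad():
    output = network(data)
    predictions = output.argmax(dim=1)  # highest score per sample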
Customization
Changing the Structure of the Model
When we create the Net class, we define the structure of our model. In this example repo, we set up two convolutional layers, a dropout layer, and two linear layers. The convolutional layers take a number of 2D planes as input, perform a 2D convolution, and output a different number of planes. Because our flat grayscale image fits into a single plane, our first Conv2d layer has a single input channel. Conv2d layers are specialized for image processing.
To use a traditional RGB image as input, we would increase the number of input channels to 3, one for each color. The dropout layer randomly zeroes out some channels. The linear layers apply a linear transformation to incoming data. Because our task has ten categories, the final linear layer has ten outputs. These return a score for each category, roughly corresponding to a probability or confidence, and we take the highest one as the prediction. The layers are applied in this order: first convolutional layer, second convolutional layer, dropout layer, first linear layer, second linear layer. This order can be changed by modifying the order in which they are used in the Net class's forward method.
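For reference, here is a sketch of a Net class matching that description; the exact channel counts and kernel sizes are illustrative assumptions, not necessarily the repo's values:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # one grayscale input channel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()               # randomly zeroes whole channels
        self.fc1 = nn.Linear(320, 50)                  # 20 channels x 4x4 after pooling
        self.fc2 = nn.Linear(50, 10)                   # ten outputs, one per category

    def forward(self, x):                              # layers applied in the order above
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)                            # flatten for the linear layers
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)                 # log-scores; argmax is the prediction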
Changing the Details of How We Train the Model
Before we train the model, we define a number of constants that change how that training takes place. The first is n_epochs, which defines how many training epochs we put the model through.
During each epoch, we feed in a number of training examples before stopping, at which point we test and save the model.
- batch_size_train sets how many training examples are fed to the model in each training batch.
- batch_size_test defines how many examples are used in each batch when testing the model after training.
- learning_rate is a property of the optimizer that scales the weight updates made during the backpropagation step, causing bigger or smaller changes to the model weights.
- momentum determines how much of each weight update carries over between backpropagation steps.
Because backpropagation uses calculus to determine the model weight changes, the magnitude of those changes can be affected by the gradient of the previous backpropagation step; a sketch of how these constants fit together follows below.
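Here is a sketch of how these constants typically come together with an SGD optimizer; the values are illustrative, not prescribed by the repo:

import torch.optim as optim

n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5

# `network` is an instance of the Net class sketched earlier; learning_rate
# scales each weight update, and momentum carries part of the previous
# update's gradient into the next one.
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)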