Classification

Learn how to train and use the CNN model for MNIST datasets.

Chapter Goals:

  • Understand how hand-drawn digits are processed and passed into the model for classification

A. Model logistics

The run_model_setup function below shows how to set up and train the CNN we’ve coded:

Press + to interact
def run_model_setup(self, inputs, labels, is_training):
logits = self.model_layers(inputs, is_training)
# convert logits to probabilities with softmax activation
self.probs = tf.nn.softmax(logits, name='probs')
# round probabilities
self.predictions = tf.math.argmax(
self.probs, axis=-1, name='predictions')
class_labels = tf.math.argmax(labels, axis=-1)
# find which predictions were correct
is_correct = tf.math.equal(
self.predictions, class_labels)
is_correct_float = tf.cast(
is_correct,
tf.float32)
# compute ratio of correct to incorrect predictions
self.accuracy = tf.math.reduce_mean(
is_correct_float)
# train model
if self.is_training:
labels_float = tf.cast(
labels, tf.float32)
# compute the loss using cross_entropy
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
labels=labels_float,
logits=logits)
self.loss = tf.math.reduce_mean(
cross_entropy)
# use adam to train model
adam = tf.compat.v1.train.AdamOptimizer()
self.train_op = adam.minimize(
self.loss, global_step=self.global_step)

For more explanation of the code, see the Machine Learning for Software Engineers course on Educative.

B. Real data

After training a model on the MNIST dataset, it is ready to classify real hand-drawn digits. Using the techniques from the Image Processing section, we can decode the hand-drawn image to obtain its pixel data (in grayscale format) and then resize it to the same dimensions as the MNIST image data. Since our model inputs have shape (batch_size, input_dim**2), we flatten the image’s pixel data and reshape it to (1, input_dim**2).

C. Classifying hand-drawn digits

The code below runs a digit classifier implemented in the backend. It will prompt you to draw a digit. The model will predict which digit you drew.

Press + to interact
run_digit_recognizer()

Get hands-on with 1300+ tech skills courses.