Classification
Learn how to train the CNN model on the MNIST dataset and use it to classify digits.
Chapter Goals:
- Understand how hand-drawn digits are processed and passed into the model for classification
A. Model logistics
The run_model_setup function below shows how to set up and train the CNN we’ve coded:
def run_model_setup(self, inputs, labels, is_training):
    logits = self.model_layers(inputs, is_training)
    # convert logits to probabilities with softmax activation
    self.probs = tf.nn.softmax(logits, name='probs')
    # pick the most probable class for each example
    self.predictions = tf.math.argmax(self.probs, axis=-1, name='predictions')
    # convert one-hot labels to class indices
    class_labels = tf.math.argmax(labels, axis=-1)
    # find which predictions were correct
    is_correct = tf.math.equal(self.predictions, class_labels)
    is_correct_float = tf.cast(is_correct, tf.float32)
    # accuracy is the fraction of predictions that were correct
    self.accuracy = tf.math.reduce_mean(is_correct_float)
    # train model
    if self.is_training:
        labels_float = tf.cast(labels, tf.float32)
        # compute the loss using cross entropy
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=labels_float, logits=logits)
        self.loss = tf.math.reduce_mean(cross_entropy)
        # use Adam to train the model
        adam = tf.compat.v1.train.AdamOptimizer()
        self.train_op = adam.minimize(self.loss, global_step=self.global_step)
For more explanation of the code, see the Machine Learning for Software Engineers course on Educative.
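To make the arithmetic in run_model_setup concrete, here is a minimal sketch in plain Python (no TensorFlow) of the same steps: softmax, argmax, accuracy, and mean cross-entropy. The helper functions, the logits, and the labels are made up for illustration and are not part of the course code; the example uses 3 classes rather than the 10 digit classes of MNIST.

```python
import math


def softmax(logits):
    """Convert one row of logits into probabilities that sum to 1."""
    largest = max(logits)
    # Subtracting the max does not change the result but avoids overflow.
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Return the index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def cross_entropy(one_hot_label, probs):
    """Cross-entropy between a one-hot label and predicted probabilities."""
    return -sum(y * math.log(p) for y, p in zip(one_hot_label, probs) if y > 0)


# Hypothetical batch: 3 examples, 3 classes, one-hot labels.
logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.2, 0.2, 0.4]]
labels = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

probs = [softmax(row) for row in logits]
predictions = [argmax(row) for row in probs]
class_labels = [argmax(row) for row in labels]

# 1.0 where the prediction matches the label, 0.0 otherwise.
is_correct = [float(p == c) for p, c in zip(predictions, class_labels)]
accuracy = sum(is_correct) / len(is_correct)

# Mean cross-entropy over the batch, the quantity an optimizer would minimize.
loss = sum(cross_entropy(y, p) for y, p in zip(labels, probs)) / len(labels)

print(predictions)  # [0, 1, 0]
print(accuracy)
print(loss)
```

Two of the three predictions match their labels, so the accuracy is 2/3. The third example is misclassified and contributes the largest term to the loss.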
B. Real data
After training on the MNIST dataset, the model is ready to classify real hand-drawn digits. Using the techniques from the Image Processing section, we can decode the hand-drawn image to obtain its grayscale pixel data and then resize it to the same dimensions as the MNIST images. Since our model inputs have shape (batch_size, input_dim**2), we flatten the image’s pixel data and reshape it to (1, input_dim**2).
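The flatten-and-reshape step can be sketched in plain Python as follows. This is only an illustration under stated assumptions: flatten_image is a hypothetical helper, the image is a stand-in grid of integers rather than a decoded drawing, and input_dim is set to 4 to keep the output small (MNIST images are 28 x 28). An array library's reshape would typically do this in one call.

```python
def flatten_image(pixels):
    """Flatten a 2D grid of grayscale pixels row by row and add a batch dimension."""
    flat = [value for row in pixels for value in row]
    # Wrapping in a list gives shape (1, height * width): a batch of one image.
    return [flat]


input_dim = 4  # hypothetical size for illustration

# Stand-in for a decoded and resized grayscale image of shape (input_dim, input_dim).
image = [[row * input_dim + col for col in range(input_dim)]
         for row in range(input_dim)]

model_input = flatten_image(image)
print(len(model_input), len(model_input[0]))  # 1 16
```

The batch dimension of 1 reflects that a single drawing is classified at a time, while the second dimension matches the input_dim**2 features the model expects.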
C. Classifying hand-drawn digits
The code below runs a digit classifier implemented in the backend. It will prompt you to draw a digit. The model will predict which digit you drew.
run_digit_recognizer()