Save For Inference
Save a condensed model to be used for inference on real-time data.
Chapter Goals:
- Learn how to save a model for inference
A. Deploying a model
As mentioned in chapter 4, a saved model checkpoint consists of three files: .data, .index, and .meta. The .meta file contains the entire computation graph structure, which includes the training data when the dataset is built from in-memory arrays, so it can get quite large. The large file size becomes an issue when deploying an inference model.
An inference model is a fully trained and evaluated model used to make predictions on real-time data. When we deploy an inference model for production, we don’t usually deploy the code used to build the model, either for proprietary reasons or because there are too many auxiliary code files. When we don’t have the code that sets up the inference graph, we need a separate file that specifies the computation graph’s structure.
B. Inference graph
The main issue with the .meta file is that it contains many portions of the computation graph that are unnecessary for inference. For inference, the only input we need is a tf.compat.v1.placeholder to represent the input data. We don’t need any parts of the computation graph specific to training, such as the loss calculation or the dataset.
So instead of using a training checkpoint for the inference model, we create a bare-bones computation graph, consisting only of the input placeholder and the computations necessary to obtain a prediction.
If the model just finished training (meaning the training computation graph is still in memory), it’s necessary to call tf.compat.v1.reset_default_graph prior to building the inference graph, in order to avoid graph conflicts.
import tensorflow as tf

# Stands in for a graph that is still in memory after training
inputs = tf.compat.v1.placeholder(tf.float32, shape=(None, 3), name='inputs')
logits = tf.keras.layers.Dense(1, name='logits')(inputs)

try:
    logits = tf.keras.layers.Dense(1, name='logits')(inputs)
except ValueError:
    # Need to reset graph
    tf.compat.v1.reset_default_graph()
    inputs = tf.compat.v1.placeholder(tf.float32, shape=(None, 3), name='inputs')
    logits = tf.keras.layers.Dense(1, name='logits')(inputs)

print(logits)
C. Saving the model
We save the inference model using tf.compat.v1.saved_model.simple_save. The function’s first argument is a tf.compat.v1.Session object and the second argument is the path to the directory where we save the inference model. Note that this directory must not already exist.
The third argument is a dictionary that maps string labels (the keys) to the input tensor(s) (the values). The fourth required argument is a dictionary of the same form for the output tensor(s), e.g. the model prediction.
The function will save a file called saved_model.pb and a directory called variables in the specified directory.
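To make the four arguments concrete, here is a minimal sketch of a simple_save call. It is not the course's model: the one-layer graph, the variable names weights and bias, and the temporary export path are all assumptions chosen for illustration, and initializing the variables stands in for restoring trained parameters from a checkpoint.

```python
import os
import tempfile

import tensorflow as tf

# simple_save requires that the export directory does not exist yet, so we
# point it at a fresh subdirectory of a newly created temporary directory.
export_dir = os.path.join(tempfile.mkdtemp(), 'inference_dir')

# Build and save inside a dedicated graph so nothing clashes with a graph
# that may already be in memory.
with tf.Graph().as_default():
    inputs = tf.compat.v1.placeholder(
        tf.float32, shape=(None, 3), name='inputs')
    # Hypothetical stand-in model: a single weight matrix and bias.
    weights = tf.compat.v1.get_variable('weights', shape=(3, 1))
    bias = tf.compat.v1.get_variable(
        'bias', shape=(1,), initializer=tf.compat.v1.zeros_initializer())
    logits = tf.add(tf.matmul(inputs, weights), bias, name='logits')

    with tf.compat.v1.Session() as sess:
        # Stands in for restoring trained weights from a checkpoint.
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.saved_model.simple_save(
            sess,                # 1st argument: the session
            export_dir,          # 2nd argument: a directory that must not exist
            {'inputs': inputs},  # 3rd argument: label -> input tensor
            {'logits': logits})  # 4th argument: label -> output tensor

print(sorted(os.listdir(export_dir)))
```

The printed listing should include saved_model.pb and the variables directory; depending on the TensorFlow version, other files may appear alongside them.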
ls inference_dir
In the example above, inference_dir is the directory where tf.compat.v1.saved_model.simple_save saved the inference model. The saved_model.pb file contains the bare-bones computation graph, and is much smaller than the corresponding .meta file. The variables directory contains the model’s saved parameters.
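The same inspection can be done from Python, which also shows how large each exported file is. The helper below is a hypothetical convenience function (describe_export is not part of TensorFlow) that uses only the standard library; it is a sketch for checking an export, not a required step.

```python
from pathlib import Path


def describe_export(export_dir):
    """Return a sorted list of (relative path, size in bytes) for every
    file found under export_dir, including files in subdirectories."""
    root = Path(export_dir)
    return sorted(
        (path.relative_to(root).as_posix(), path.stat().st_size)
        for path in root.rglob('*')
        if path.is_file()
    )
```

Calling describe_export('inference_dir') after a save should list saved_model.pb along with the files inside the variables directory, which makes it easy to compare the graph file's size against a checkpoint's .meta file.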
In the next chapter, you’ll learn how to restore the inference model and make predictions.
Time to Code!
In this chapter you’ll complete the save_inference_graph function, which saves the model’s computation graph for inference. The function is already filled with code that restores the model state from a checkpoint.
The input dictionary for the inference graph contains input_placeholder as its only value, which represents the input data for the inference graph.
Set input_dict equal to a dictionary with a single key, 'inputs', that maps to input_placeholder.
The output dictionary for the inference graph contains self.predictions as its only value. The corresponding key is 'predictions'.
Set output_dict equal to a dictionary consisting of the specified key-value pair.
After creating the dictionaries for the inference graph’s input and output, we can save the model using tf.compat.v1.saved_model.simple_save.
Call the specified function with sess, export_dir, input_dict, and output_dict as the four input arguments.
import numpy as np
import tensorflow as tf

class ClassificationModel(object):
    def __init__(self, output_size):
        self.output_size = output_size

    # Save the model's computation graph for inference
    def save_inference_graph(self, sess, ckpt_dir, input_placeholder, export_dir):
        ckpt = tf.compat.v1.train.get_checkpoint_state(ckpt_dir)
        if ckpt is not None:
            saver = tf.compat.v1.train.Saver()
            saver.restore(sess, ckpt.model_checkpoint_path)
            #CODE HERE
            pass

    # See the "Efficient Data Processing Techniques" section for details
    def dataset_from_numpy(self, input_data, batch_size, labels=None,
                           is_training=True, num_epochs=None):
        dataset_input = input_data if labels is None else (input_data, labels)
        dataset = tf.compat.v1.data.Dataset.from_tensor_slices(dataset_input)
        if is_training:
            dataset = dataset.shuffle(len(input_data)).repeat(num_epochs)
        return dataset.batch(batch_size)

    # See the "Machine Learning for Software Engineers" course on Educative
    def run_model_setup(self, inputs, labels, hidden_layers, is_training,
                        calculate_accuracy=True):
        layer = inputs
        for num_nodes in hidden_layers:
            layer = tf.keras.layers.Dense(
                num_nodes, activation=tf.nn.relu)(layer)
        logits = tf.keras.layers.Dense(
            self.output_size, name='logits')(layer)
        self.probs = tf.compat.v1.nn.softmax(logits, name='probs')
        self.predictions = tf.math.argmax(
            self.probs, axis=-1, name='predictions')
        if calculate_accuracy:
            class_labels = tf.math.argmax(labels, axis=-1)
            is_correct = tf.equal(self.predictions, class_labels)
            is_correct_float = tf.cast(is_correct, tf.float32)
            self.accuracy = tf.math.reduce_mean(is_correct_float)
        if labels is not None:
            labels_float = tf.cast(labels, tf.float32)
            cross_entropy = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                labels=labels_float, logits=logits)
            self.loss = tf.math.reduce_mean(cross_entropy)
        if is_training:
            adam = tf.compat.v1.train.AdamOptimizer()
            self.train_op = adam.minimize(
                self.loss, global_step=self.global_step)

    # Run training of the classification model
    def run_model_training(self, input_data, labels, hidden_layers,
                           batch_size, num_epochs, ckpt_dir):
        self.global_step = tf.compat.v1.train.get_or_create_global_step()
        dataset = self.dataset_from_numpy(
            input_data, batch_size, labels=labels, num_epochs=num_epochs)
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        inputs, labels = iterator.get_next()
        self.run_model_setup(inputs, labels, hidden_layers, True)
        tf.summary.scalar('accuracy', self.accuracy)
        tf.summary.histogram('inputs', inputs)
        log_vals = {'loss': self.loss, 'step': self.global_step}
        logging_hook = tf.compat.v1.train.LoggingTensorHook(
            log_vals, every_n_iter=1000)
        nan_hook = tf.compat.v1.train.NanTensorHook(self.loss)
        hooks = [nan_hook, logging_hook]
        with tf.compat.v1.train.MonitoredTrainingSession(
                checkpoint_dir=ckpt_dir, hooks=hooks) as sess:
            while not sess.should_stop():
                sess.run(self.train_op)