Training Sampler
Create a Training Sampler object for training the decoder.
Chapter Goals:
- Learn about the decoding process during training
- Create a TrainingSampler object to use for decoding
A. Decoding during training
During training, we have access to both the input and output sequences of a training pair. This means that we can use the output sequence's ground truth tokens as input for the decoder.
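This idea is known as teacher forcing. A minimal sketch (with hypothetical token IDs and a hypothetical helper name) of how ground truth tokens become decoder inputs:

```python
# Sketch of teacher forcing with hypothetical token IDs.
# During training, the decoder's input at step t is the ground truth
# token from step t-1, not the decoder's own previous prediction.

START_TOKEN = 1  # hypothetical start-of-sequence ID

def make_decoder_inputs(ground_truth):
    """Shift the ground truth right by one step and prepend the start token."""
    return [START_TOKEN] + ground_truth[:-1]

target = [7, 12, 4, 2]  # ground truth output sequence
decoder_inputs = make_decoder_inputs(target)
print(decoder_inputs)  # [1, 7, 12, 4]
```

Because the decoder always sees correct previous tokens, training is stable and parallelizable; at inference time no ground truth exists, so a different sampling strategy is needed.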
We'll get into the specifics of TensorFlow decoding, i.e. using the decoder to generate model outputs, in later chapters. However, before we perform any decoding, we need to set up a Sampler object. For training, the Sampler object instance we use is the TrainingSampler.
Below, we create a TrainingSampler object for decoding during training.
import tensorflow_addons as tfa

# In TensorFlow Addons, the former Helper classes are called Samplers
sampler = tfa.seq2seq.TrainingSampler()
The TrainingSampler object is initialized with the (embedded) ground truth sequences and the lengths of the ground truth sequences. Note that we use separate embedding models for the encoder input and the decoder input (i.e. ground truth tokens). This is because there are different word relationships in the input and output sequences for a seq2seq task, and sometimes the sequences can be completely different (e.g. machine translation).
Time to Code!
In this chapter you'll be filling in part of the create_decoder_sampler function, which creates the Sampler object for the decoder. Specifically, in this chapter you'll be focusing on creating the TrainingSampler object.
We use the decoder differently depending on whether we're in training or inference mode. If we're in training mode, we use the decoder_inputs as the ground truth tokens.
Before using the ground truth tokens, we need to convert them into embeddings. We provide a function called get_embeddings, which returns the embeddings and sequence lengths. The function takes in the input sequences and the name of the embedding model as required arguments.
Inside the if block, set the tuple dec_embeddings, dec_seq_lens equal to self.get_embeddings applied with decoder_inputs and 'decoder_emb' as the required arguments.
During training, we use the TrainingSampler object. It takes in the ground truth embeddings and sequence lengths as required arguments.

Inside the if block, set sampler equal to tfa.seq2seq.sampler.TrainingSampler.
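To build intuition for what the TrainingSampler does at each decoding step, here is a toy stand-in written in plain Python. The class name and method signature are illustrative only, not the real tfa.seq2seq API:

```python
# Toy stand-in for TrainingSampler's behavior: at each decoding step it
# emits the ground-truth embedding for that step, ignoring whatever the
# decoder actually predicted. Names here are illustrative only.

class ToyTrainingSampler:
    def __init__(self, embeddings, seq_len):
        self.embeddings = embeddings  # one ground-truth vector per time step
        self.seq_len = seq_len

    def next_inputs(self, time, predicted_id):
        finished = time + 1 >= self.seq_len
        # The ground truth is used regardless of the prediction
        return finished, self.embeddings[time]

embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
sampler = ToyTrainingSampler(embeddings, seq_len=3)
finished, next_in = sampler.next_inputs(time=1, predicted_id=42)
print(finished, next_in)  # False [0.3, 0.4]
```

This is the essence of teacher forcing: the decoder's mistakes never propagate to later steps during training, because the next input always comes from the ground truth.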
import tensorflow as tf
import tensorflow_addons as tfa

# Seq2seq model
class Seq2SeqModel(object):
    def __init__(self, vocab_size, num_lstm_layers, num_lstm_units):
        self.vocab_size = vocab_size
        # Extended vocabulary includes start, stop token
        self.extended_vocab_size = vocab_size + 2
        self.num_lstm_layers = num_lstm_layers
        self.num_lstm_units = num_lstm_units
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=vocab_size)

    # Get embeddings for input/output sequences
    def get_embeddings(self, sequences, scope_name):
        with tf.compat.v1.variable_scope(scope_name,
                                         reuse=tf.compat.v1.AUTO_REUSE):
            cat_column = tf.compat.v1.feature_column \
                .categorical_column_with_identity(
                    'sequences', self.extended_vocab_size)
            embed_size = int(self.extended_vocab_size**0.25)
            embedding_column = tf.compat.v1.feature_column.embedding_column(
                cat_column, embed_size)
            seq_dict = {'sequences': sequences}
            embeddings = tf.compat.v1.feature_column \
                .input_layer(seq_dict, [embedding_column])
            sequence_lengths = tf.compat.v1.placeholder(
                "int64", shape=(None,),
                name=scope_name + "/sinput_layer/sequence_length")
            return embeddings, tf.cast(sequence_lengths, tf.int32)

    # Create the sampler for decoding
    def create_decoder_sampler(self, decoder_inputs, is_training, batch_size):
        if is_training:
            # CODE HERE
            pass
        else:
            # IGNORE FOR NOW
            pass
        return sampler, dec_seq_lens