Decoder Object

Learn about the decoder object for the encoder-decoder model.

Chapter Goals:

  • Convert the encoder’s final state into the proper format for decoding with attention
  • Create a BasicDecoder object to use for decoding

A. Creating the initial state

The final state from the encoder is a tuple containing an LSTMStateTuple object for each layer of the BiLSTM. However, if we want to use this as the initial state for an attention-wrapped decoder, we need to convert it into an AttentionWrapperState.

For the conversion, we use the cell's get_initial_state function with the required arguments inputs, batch_size, and dtype (set to tf.float32). The result is the initial state for the attention-wrapped decoder.

Below we demonstrate how to create the decoder’s initial state from the encoder’s final state.

import tensorflow as tf
batch_size = tf.constant(10)
initial_state = dec_cell.get_initial_state(inputs, batch_size=batch_size, dtype=tf.float32)
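As a small runnable illustration of what get_initial_state returns, the sketch below uses a plain tf.keras.layers.LSTMCell as a stand-in for the attention-wrapped decoder cell, with made-up sizes (both assumptions for this example). Note that the exact signature of get_initial_state varies between Keras versions, so the sketch tries both forms:

```python
import tensorflow as tf

# Made-up sizes for illustration only
num_units = 8

# A plain LSTM cell stands in for the attention-wrapped decoder cell here
cell = tf.keras.layers.LSTMCell(num_units)

try:
    # Older Keras versions require an explicit dtype
    initial_state = cell.get_initial_state(batch_size=10, dtype=tf.float32)
except TypeError:
    # Newer Keras versions accept only batch_size
    initial_state = cell.get_initial_state(batch_size=10)

# The LSTM initial state is a pair of zero tensors: [hidden state, cell state]
print([tuple(s.shape) for s in initial_state])
```

Either way, the result is a zero-filled state whose tensors each have shape (batch_size, num_units), here (10, 8).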

B. The BasicDecoder object

The decoder object that we use for decoding is the BasicDecoder. To create an instance of the BasicDecoder object, we need to pass in the decoder cell and sampler object as required arguments.

import tensorflow as tf
import tensorflow_addons as tfa
decoder = tfa.seq2seq.BasicDecoder(dec_cell, sampler)

The BasicDecoder constructor has a keyword argument called output_layer, which can be used to apply a fully-connected layer to the model’s outputs. This is a nice shortcut when calculating the model’s logits.

import tensorflow as tf
import tensorflow_addons as tfa
num_units = 500 # extended vocab size
projection_layer = tf.keras.layers.Dense(num_units)
decoder = tfa.seq2seq.BasicDecoder(dec_cell, sampler, output_layer=projection_layer)

Rather than applying a one-off fully-connected function, we create a reusable fully-connected layer object with the tf.keras.layers.Dense constructor. Its required argument is the number of hidden units. In the example above, we set it equal to the extended vocabulary size, which gives the logits the proper shape.
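To see why this gives the right logits shape, consider the self-contained sketch below (all sizes are made up for illustration). Applying a Dense layer with extended-vocabulary-size units to decoder outputs of shape (batch, time, units) produces logits of shape (batch, time, vocab):

```python
import tensorflow as tf

# Made-up sizes for illustration
batch_size, max_time, num_decode_units = 4, 7, 16
extended_vocab_size = 500

# Stand-in for the decoder's output tensor
dec_outputs = tf.random.normal((batch_size, max_time, num_decode_units))

# The projection layer maps each decoder output vector to vocabulary logits
projection_layer = tf.keras.layers.Dense(extended_vocab_size)
logits = projection_layer(dec_outputs)

print(logits.shape)  # (4, 7, 500)
```

Dense operates on the last axis only, so every time step of every batch element gets its own vector of vocabulary logits.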

Time to Code!

In this chapter you’ll be completing the create_basic_decoder function, which is used in the model’s decoder function to create a BasicDecoder object from a decoder LSTM cell and a sampler object.

When creating the decoder object, we’ll apply a projection layer to the end, which will calculate the model’s logits directly.

Set projection_layer equal to tf.keras.layers.Dense initialized with extended_vocab_size.

Before we can create the decoder object, we need to make sure that the initial state for the decoder is in the correct format.

Set initial_state equal to dec_cell.get_initial_state applied with inputs, batch_size, and dtype.

We need to convert the integer batch_size into a 0-D tensor by using tf.constant().

Set batch_size equal to tf.constant applied with batch_size.
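For reference, tf.constant applied to a Python integer yields exactly the 0-D (scalar) tensor this step asks for:

```python
import tensorflow as tf

# A Python int becomes a rank-0 (scalar) tensor
batch_size = tf.constant(10)
print(batch_size.shape)  # () — a 0-D tensor
```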

We can now create the decoder object using the BasicDecoder constructor. We’ll use dec_cell, sampler, and initial_state as the required arguments for initialization.

Set decoder equal to tfa.seq2seq.sampler.BasicDecoder initialized with the specified required arguments as well as projection_layer for the output_layer keyword argument. Then return decoder.

import tensorflow as tf
import tensorflow_addons as tfa

def create_basic_decoder(enc_outputs, extended_vocab_size, batch_size, final_state, dec_cell, sampler):
    # CODE HERE
    pass

# Seq2seq model
class Seq2SeqModel(object):
    def __init__(self, vocab_size, num_lstm_layers, num_lstm_units):
        self.vocab_size = vocab_size
        # Extended vocabulary includes start, stop token
        self.extended_vocab_size = vocab_size + 2
        self.num_lstm_layers = num_lstm_layers
        self.num_lstm_units = num_lstm_units
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=vocab_size)

    def make_lstm_cell(self, dropout_keep_prob, num_units):
        cell = tf.keras.layers.LSTMCell(num_units, dropout=dropout_keep_prob)
        return cell

    # Create multi-layer LSTM
    def stacked_lstm_cells(self, is_training, num_units):
        dropout_keep_prob = 0.5 if is_training else 1.0
        cell_list = [self.make_lstm_cell(dropout_keep_prob, num_units) for i in range(self.num_lstm_layers)]
        cell = tf.keras.layers.StackedRNNCells(cell_list)
        return cell

    # Get embeddings for input/output sequences
    def get_embeddings(self, sequences, scope_name):
        with tf.compat.v1.variable_scope(scope_name, reuse=tf.compat.v1.AUTO_REUSE):
            cat_column = tf.compat.v1.feature_column \
                .categorical_column_with_identity(
                    'sequences', self.extended_vocab_size)
            embed_size = int(self.extended_vocab_size**0.25)
            embedding_column = tf.compat.v1.feature_column.embedding_column(
                cat_column, embed_size)
            seq_dict = {'sequences': sequences}
            embeddings = tf.compat.v1.feature_column \
                .input_layer(
                    seq_dict, [embedding_column])
            sequence_lengths = tf.compat.v1.placeholder(
                "int64", shape=(None,),
                name=scope_name + "/sinput_layer/sequence_length")
        return embeddings, tf.cast(sequence_lengths, tf.int32)

    # Helper function to combine BiLSTM encoder outputs
    def combine_enc_outputs(self, enc_outputs):
        enc_outputs_fw, enc_outputs_bw = enc_outputs
        return tf.concat([enc_outputs_fw, enc_outputs_bw], -1)

    # Create the stacked LSTM cells for the decoder
    def create_decoder_cell(self, enc_outputs, input_seq_lens, is_training):
        num_decode_units = self.num_lstm_units * 2
        dec_cell = self.stacked_lstm_cells(is_training, num_decode_units)
        combined_enc_outputs = self.combine_enc_outputs(enc_outputs)
        attention_mechanism = tfa.seq2seq.LuongAttention(
            num_decode_units, combined_enc_outputs,
            memory_sequence_length=input_seq_lens)
        dec_cell = tfa.seq2seq.AttentionWrapper(
            dec_cell, attention_mechanism,
            attention_layer_size=num_decode_units)
        return dec_cell

    # Create the sampler for decoding
    def create_decoder_sampler(self, decoder_inputs, is_training, batch_size):
        if is_training:
            dec_embeddings, dec_seq_lens = self.get_embeddings(decoder_inputs, 'decoder_emb')
            sampler = tfa.seq2seq.sampler.TrainingSampler()
        else:
            # IGNORE FOR NOW
            pass
        return sampler, dec_seq_lens

    # Create the decoder for the model
    def Decoder(self, enc_outputs, input_seq_lens, final_state, batch_size, sampler, dec_seq_lens):
        is_training = True
        dec_cell = self.create_decoder_cell(enc_outputs, input_seq_lens, is_training)
        decoder = create_basic_decoder(enc_outputs, self.extended_vocab_size, batch_size, final_state, dec_cell, sampler)
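For reference, one possible completion of create_basic_decoder is sketched below. Treat it as a sketch rather than the official solution: it assumes tensorflow_addons is available as tfa, it passes enc_outputs as the inputs argument to get_initial_state (the exercise does not pin down which tensor to use), and in current tensorflow_addons versions the initial state is supplied when the decoder is invoked rather than to the BasicDecoder constructor:

```python
import tensorflow as tf

def create_basic_decoder(enc_outputs, extended_vocab_size, batch_size,
                         final_state, dec_cell, sampler):
    # Projection layer with as many units as the extended vocabulary,
    # so the decoder emits logits directly
    projection_layer = tf.keras.layers.Dense(extended_vocab_size)
    # Convert the integer batch size into a 0-D tensor
    batch_size = tf.constant(batch_size)
    # Initial state in the format the attention-wrapped cell expects;
    # passing enc_outputs as inputs is an assumption of this sketch
    initial_state = dec_cell.get_initial_state(
        enc_outputs, batch_size=batch_size, dtype=tf.float32)
    # Deferred import so this sketch only needs tensorflow_addons when called
    import tensorflow_addons as tfa
    # BasicDecoder ties together the cell, the sampler, and the projection;
    # initial_state is typically handed to the decoder at call time
    decoder = tfa.seq2seq.BasicDecoder(
        dec_cell, sampler, output_layer=projection_layer)
    return decoder
```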
