Decoder Object
Learn about the decoder object for the encoder-decoder model.
Chapter Goals:
- Convert the encoder's final state into the proper format for decoding with attention
- Create a BasicDecoder object to use for decoding
A. Creating the initial state
The final state from the encoder is a tuple containing an LSTMStateTuple object for each layer of the BiLSTM. However, if we want to use this as the initial state for an attention-wrapped decoder, we need to convert it into an AttentionWrapperState.

For the conversion we need to use get_initial_state with the required arguments inputs, batch_size, and tf.float32 as the dtype. The result is the initial state for the attention-wrapped decoder.
Below we demonstrate how to create the decoder’s initial state from the encoder’s final state.
import tensorflow as tf

batch_size = tf.constant(10)
initial_state = dec_cell.get_initial_state(inputs, batch_size=batch_size, dtype=tf.float32)
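Note that get_initial_state returns a zero-filled AttentionWrapperState. If we want decoding to start from the encoder's final state rather than from zeros, one option is the state's clone method. This is a minimal sketch, assuming dec_cell is an AttentionWrapper and final_state matches the structure of the wrapped cell's state:

import tensorflow as tf

# Build the zero-filled attention state, then swap in the
# encoder's final state as the underlying LSTM cell state
zero_state = dec_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
initial_state = zero_state.clone(cell_state=final_state)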
B. The BasicDecoder object
The decoder object that we use for decoding is the BasicDecoder. To create an instance of the BasicDecoder object, we need to pass in the decoder cell and the sampler object as the required arguments.
import tensorflow as tf
import tensorflow_addons as tfa

decoder = tfa.seq2seq.BasicDecoder(dec_cell, sampler)
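The sampler determines how the decoder produces its input at each time step. As a minimal sketch, at training time we can use the TrainingSampler (which also appears in the model code below) to feed the ground-truth target embeddings to the decoder:

import tensorflow_addons as tfa

# Teacher forcing: read the ground-truth target embeddings
# as the decoder input at each time step
sampler = tfa.seq2seq.sampler.TrainingSampler()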
The BasicDecoder constructor has a keyword argument called output_layer, which can be used to apply a fully-connected layer to the model's outputs. This is a nice shortcut when calculating the model's logits.
import tensorflow as tf
import tensorflow_addons as tfa

num_units = 500  # extended vocab size
projection_layer = tf.keras.layers.Dense(num_units)
decoder = tfa.seq2seq.BasicDecoder(dec_cell, sampler, output_layer=projection_layer)
Rather than applying a fully-connected layer directly to the outputs, we create a layer object using the tf.keras.layers.Dense constructor. The required argument for initialization is the number of hidden units. In the example above, we set it equal to the extended vocabulary size, which gives us the proper shape for the logits.
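To see why this gives the proper shape, here is a quick sketch with made-up sizes, showing that the projection layer turns each decoder output vector into one logit per word of the extended vocabulary:

import tensorflow as tf

extended_vocab_size = 502  # hypothetical: vocab size + 2 special tokens
num_decode_units = 1000    # hypothetical decoder cell output size

projection_layer = tf.keras.layers.Dense(extended_vocab_size)
# Fake batch of decoder cell outputs: [batch, time steps, cell units]
cell_outputs = tf.random.normal([10, 20, num_decode_units])
logits = projection_layer(cell_outputs)
print(logits.shape)  # (10, 20, 502): one logit per extended-vocab word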
Time to Code!
In this chapter you'll be completing the create_basic_decoder function, which is used in the model's decoder function to create a BasicDecoder object from a decoder LSTM cell and a sampler object.
When creating the decoder object, we’ll apply a projection layer to the end, which will calculate the model’s logits directly.
Set projection_layer equal to tf.keras.layers.Dense initialized with extended_vocab_size.
Before we can create the decoder object, we need to make sure that the initial state for the decoder is in the correct format.
Set initial_state equal to dec_cell.get_initial_state applied with the inputs, batch_size, and dtype arguments.
We need to convert the integer batch_size into a 0-D tensor by using tf.constant().
Set batch_size equal to tf.constant applied with batch_size.
We can now create the decoder object using the BasicDecoder constructor. We'll use dec_cell and sampler as the required arguments for initialization; the initial_state we just built isn't a constructor argument, and will instead be supplied when decoding begins.
Set decoder equal to tfa.seq2seq.BasicDecoder initialized with the specified required arguments as well as projection_layer for the output_layer keyword argument. Then return decoder.
import tensorflow as tf
import tensorflow_addons as tfa

def create_basic_decoder(enc_outputs, extended_vocab_size, batch_size, final_state, dec_cell, sampler):
    # CODE HERE
    pass

# Seq2seq model
class Seq2SeqModel(object):
    def __init__(self, vocab_size, num_lstm_layers, num_lstm_units):
        self.vocab_size = vocab_size
        # Extended vocabulary includes start, stop token
        self.extended_vocab_size = vocab_size + 2
        self.num_lstm_layers = num_lstm_layers
        self.num_lstm_units = num_lstm_units
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=vocab_size)

    def make_lstm_cell(self, dropout_rate, num_units):
        cell = tf.keras.layers.LSTMCell(num_units, dropout=dropout_rate)
        return cell

    # Create multi-layer LSTM
    def stacked_lstm_cells(self, is_training, num_units):
        dropout_rate = 0.5 if is_training else 0.0
        cell_list = [self.make_lstm_cell(dropout_rate, num_units) for i in range(self.num_lstm_layers)]
        cell = tf.keras.layers.StackedRNNCells(cell_list)
        return cell

    # Get embeddings for input/output sequences
    def get_embeddings(self, sequences, scope_name):
        with tf.compat.v1.variable_scope(scope_name, reuse=tf.compat.v1.AUTO_REUSE):
            cat_column = tf.compat.v1.feature_column \
                .categorical_column_with_identity('sequences', self.extended_vocab_size)
            embed_size = int(self.extended_vocab_size**0.25)
            embedding_column = tf.compat.v1.feature_column.embedding_column(cat_column, embed_size)
            seq_dict = {'sequences': sequences}
            embeddings = tf.compat.v1.feature_column \
                .input_layer(seq_dict, [embedding_column])
            sequence_lengths = tf.compat.v1.placeholder("int64", shape=(None,), name=scope_name + "/sinput_layer/sequence_length")
            return embeddings, tf.cast(sequence_lengths, tf.int32)

    # Helper function to combine BiLSTM encoder outputs
    def combine_enc_outputs(self, enc_outputs):
        enc_outputs_fw, enc_outputs_bw = enc_outputs
        return tf.concat([enc_outputs_fw, enc_outputs_bw], -1)

    # Create the stacked LSTM cells for the decoder
    def create_decoder_cell(self, enc_outputs, input_seq_lens, is_training):
        num_decode_units = self.num_lstm_units * 2
        dec_cell = self.stacked_lstm_cells(is_training, num_decode_units)
        combined_enc_outputs = self.combine_enc_outputs(enc_outputs)
        attention_mechanism = tfa.seq2seq.LuongAttention(
            num_decode_units, combined_enc_outputs,
            memory_sequence_length=input_seq_lens)
        dec_cell = tfa.seq2seq.AttentionWrapper(
            dec_cell, attention_mechanism,
            attention_layer_size=num_decode_units)
        return dec_cell

    # Create the sampler for decoding
    def create_decoder_sampler(self, decoder_inputs, is_training, batch_size):
        if is_training:
            dec_embeddings, dec_seq_lens = self.get_embeddings(decoder_inputs, 'decoder_emb')
            sampler = tfa.seq2seq.sampler.TrainingSampler()
        else:
            # IGNORE FOR NOW
            pass
        return sampler, dec_seq_lens

    # Create the decoder for the model
    def decoder(self, enc_outputs, input_seq_lens, final_state, batch_size, sampler, dec_seq_lens):
        is_training = True
        dec_cell = self.create_decoder_cell(enc_outputs, input_seq_lens, is_training)
        decoder = create_basic_decoder(enc_outputs, self.extended_vocab_size, batch_size, final_state, dec_cell, sampler)
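For reference, here is one way the completed function could look, following the steps above. This is a sketch: in TensorFlow Addons, initial_state is not a constructor argument of BasicDecoder, so it is built here and would be supplied to the decoder once decoding starts.

import tensorflow as tf
import tensorflow_addons as tfa

def create_basic_decoder(enc_outputs, extended_vocab_size, batch_size,
                         final_state, dec_cell, sampler):
    # Projection layer that computes logits over the extended vocabulary
    projection_layer = tf.keras.layers.Dense(extended_vocab_size)
    # Convert the integer batch size into a 0-D tensor
    batch_size = tf.constant(batch_size)
    # Zero-filled initial state for the attention-wrapped decoder cell;
    # it is used later, when decoding begins
    initial_state = dec_cell.get_initial_state(enc_outputs, batch_size=batch_size, dtype=tf.float32)
    decoder = tfa.seq2seq.BasicDecoder(dec_cell, sampler, output_layer=projection_layer)
    return decoder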