Final States
Learn about the final state output of an LSTM and BiLSTM.
Chapter Goals:
- Learn about the final states for an LSTM and BiLSTM
A. The encoder
In an encoder-decoder model for seq2seq tasks, there are two components: the encoder and the decoder. The encoder is responsible for extracting useful information from the input sequence. For NLP applications, the encoder is normally an LSTM or BiLSTM.
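To make this concrete, here is a minimal sketch of the encoder-decoder wiring, assuming hypothetical sizes (4-dimensional embeddings, 5 hidden units) and using the higher-level tf.keras.layers.LSTM layer rather than the LSTMCell setup used in the rest of this chapter. The key point is that the encoder's final state is what initializes the decoder.

import tensorflow as tf

# Hypothetical embedded input sequences: (batch_size, max_seq_len, embed_dim)
encoder_inputs = tf.keras.Input(shape=(None, 4))
decoder_inputs = tf.keras.Input(shape=(None, 4))

# Encoder: we only keep the final state (h, c)
_, enc_h, enc_c = tf.keras.layers.LSTM(5, return_state=True)(encoder_inputs)

# Decoder: a regular (forward-only) LSTM seeded with the encoder's final state
decoder_outputs = tf.keras.layers.LSTM(5, return_sequences=True)(
    decoder_inputs, initial_state=[enc_h, enc_c])

model = tf.keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)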
B. LSTM final state
When using an LSTM or BiLSTM encoder, we need to pass the final state of the encoder into the decoder. The final state of an LSTM in TensorFlow is represented by the LSTMStateTuple object.
import tensorflow as tf

# The placeholder-based examples in this chapter require graph mode
tf.compat.v1.disable_eager_execution()

# Input sequences (embedded)
# Shape: (batch_size, max_seq_len, embed_dim)
input_embeddings = tf.compat.v1.placeholder(tf.float32, shape=(None, None, 4))

cell = tf.keras.layers.LSTMCell(5)
rnn = tf.keras.layers.RNN(cell, return_state=True)
output = rnn(input_embeddings)

# With return_state=True, rnn returns the two final state tensors
# alongside the sequence output, all stored in the output variable
final_state = (output[1], output[2])

# final_state is the output of our LSTM encoder.
# It contains all the information about our input sequence,
# which in this case is just a tf.compat.v1.placeholder object
print(final_state)
An LSTMStateTuple object contains two important properties: the hidden state (c) and the state output (h). The hidden state represents the internal cell state (i.e. the "memory") of the LSTM cell, while the state output is the cell's output at the final time step. Both properties are represented by tensors with shape (batch_size, hidden_units).
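As a quick sanity check on these shapes, the following sketch uses made-up sizes (batch_size=3, max_seq_len=7, embed_dim=4, hidden_units=5) and the equivalent tf.keras.layers.LSTM layer, assuming eager execution:

import numpy as np
import tensorflow as tf

# Random "embedded" input with shape (batch_size, max_seq_len, embed_dim)
inputs = np.random.rand(3, 7, 4).astype(np.float32)

# return_state=True returns the final h and c alongside the layer output
_, state_h, state_c = tf.keras.layers.LSTM(5, return_state=True)(inputs)

print(state_c.shape)  # (3, 5) -> (batch_size, hidden_units)
print(state_h.shape)  # (3, 5) -> (batch_size, hidden_units)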
C. Multi-layer final states
For a multi-layer LSTM, the final state output of tf.keras.layers.RNN is a tuple containing the final state for each layer.
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

# Input sequences (embedded)
# Shape: (batch_size, max_seq_len, embed_dim)
input_embeddings = tf.compat.v1.placeholder(tf.float32, shape=(None, None, 4))

cell1 = tf.keras.layers.LSTMCell(5)
cell2 = tf.keras.layers.LSTMCell(8)
# The cells are passed as a list so that layer order is preserved
cells = [cell1, cell2]
multi_cell = tf.keras.layers.StackedRNNCells(cells)
rnn = tf.keras.layers.RNN(multi_cell, return_state=True, dtype=tf.float32)
outputs = rnn(input_embeddings)

final_state_cell1 = outputs[1]
final_state_cell2 = outputs[2]
print(final_state_cell1)  # layer 1
print(final_state_cell2)  # layer 2
D. BiLSTM final state
The final state of a BiLSTM consists of the final state for the forward LSTM and the final state for the backward LSTM. This corresponds to a tuple containing two LSTMStateTuple objects.
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

# Input sequences (embedded)
# Shape: (batch_size, max_seq_len, embed_dim)
input_embeddings = tf.compat.v1.placeholder(tf.float32, shape=(None, None, 4))

cell = tf.keras.layers.LSTMCell(5)
# The Bidirectional wrapper runs the backward direction itself,
# so the wrapped RNN layer is left as a forward layer
rnn = tf.keras.layers.RNN(cell, return_state=True, dtype=tf.float32)
Bi_rnn = tf.keras.layers.Bidirectional(rnn, merge_mode='concat')
outputs = Bi_rnn(input_embeddings)

forward_final_state = (outputs[1], outputs[2])
backward_final_state = (outputs[3], outputs[4])
print(forward_final_state)
print(backward_final_state)
When the BiLSTM has multiple layers, the final state output of tf.keras.layers.Bidirectional is a tuple containing the final forward and backward states for each layer.
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

# Input sequences (embedded)
# Shape: (batch_size, max_seq_len, embed_dim)
input_embeddings = tf.compat.v1.placeholder(tf.float32, shape=(None, None, 4))

# Multiple cells
cell1 = tf.keras.layers.LSTMCell(5)
cell2 = tf.keras.layers.LSTMCell(9)
cells = [cell1, cell2]
# Stacking the cells
multi_cell = tf.keras.layers.StackedRNNCells(cells)
# Creating the RNN layer and wrapping it with Bidirectional
rnn = tf.keras.layers.RNN(multi_cell, return_state=True, dtype=tf.float32)
Bi_rnn = tf.keras.layers.Bidirectional(rnn, merge_mode='concat')
outputs = Bi_rnn(input_embeddings)

fw_final_state_cell1 = outputs[1]
print(fw_final_state_cell1)
fw_final_state_cell2 = outputs[2]
print(fw_final_state_cell2)
bw_final_state_cell1 = outputs[3]
print(bw_final_state_cell1)
bw_final_state_cell2 = outputs[4]
print(bw_final_state_cell2)
E. Combining forward and backward
In order to use BiLSTM final states in an encoder-decoder model, we need to combine the forward and backward states. This is because the decoder portion utilizes a regular LSTM, which only has a forward direction.
Luckily, combining the forward and backward states is rather simple. The main thing we need to do is concatenate the hidden state and state output of both the forward and backward states.
import tensorflow as tf

# Forward state of single-layer BiLSTM final states
fw_c = forward_final_state[0]
fw_h = forward_final_state[1]

# Backward state of single-layer BiLSTM final states
bw_c = backward_final_state[0]
bw_h = backward_final_state[1]

# Concatenate along the final axis
final_c = tf.concat([fw_c, bw_c], -1)
final_h = tf.concat([fw_h, bw_h], -1)
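As a follow-up check (assuming the single-layer BiLSTM above, with 5 units per direction), the concatenation doubles the final dimension, so the combined state matches a decoder LSTM with twice the hidden units:

# Each direction contributes hidden_units features,
# so the combined states have shape (batch_size, 2 * hidden_units)
print(final_c.shape)  # (None, 10)
print(final_h.shape)  # (None, 10)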
We'll describe in the next chapter how to convert the combined values into an LSTMStateTuple object.
Time to Code!
In this chapter you'll be completing the get_bi_state_parts function, which combines the forward and backward final states for a layer in the BiLSTM. The function is used as a helper in the Seq2SeqModel object's encoder function, which creates the encoder portion of the model.
When combining the forward and backward states, we use the tf.concat function to concatenate along the final axis. We perform separate concatenations for the hidden state (c) and the state output (h).
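As a quick tf.concat refresher, with hypothetical stand-in tensors of shape (batch_size, hidden_units) = (3, 5), concatenating along the final axis doubles the last dimension:

import tensorflow as tf

# Stand-in tensors for a forward and backward state part
a = tf.zeros((3, 5))
b = tf.ones((3, 5))

combined = tf.concat([a, b], -1)
print(combined.shape)  # (3, 10)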
Set bi_state_c equal to tf.concat applied with a list containing state_fw.c (state_fw[0]) and state_bw.c (state_bw[0]), in that order, and -1 as the axis.
Set bi_state_h equal to tf.concat applied with a list containing state_fw.h (state_fw[1]) and state_bw.h (state_bw[1]), in that order, and -1 as the axis.
Return a tuple containing bi_state_c as the first element and bi_state_h as the second.
import tensorflow as tf

# Get c and h vectors for bidirectional LSTM final states
def get_bi_state_parts(state_fw, state_bw):
    # CODE HERE
    pass

# Seq2seq model
class Seq2SeqModel(object):
    def __init__(self, vocab_size, num_lstm_layers, num_lstm_units):
        self.vocab_size = vocab_size
        # Extended vocabulary includes start, stop token
        self.extended_vocab_size = vocab_size + 2
        self.num_lstm_layers = num_lstm_layers
        self.num_lstm_units = num_lstm_units
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=vocab_size)

    def make_lstm_cell(self, dropout_keep_prob, num_units):
        cell = tf.keras.layers.LSTMCell(num_units, dropout=dropout_keep_prob)
        return cell

    # Create multi-layer LSTM
    def stacked_lstm_cells(self, is_training, num_units):
        dropout_keep_prob = 0.5 if is_training else 1.0
        cell_list = [self.make_lstm_cell(dropout_keep_prob, num_units)
                     for i in range(self.num_lstm_layers)]
        cell = tf.keras.layers.StackedRNNCells(cell_list)
        return cell

    # Get embeddings for input/output sequences
    def get_embeddings(self, sequences, scope_name):
        with tf.compat.v1.variable_scope(scope_name, reuse=tf.compat.v1.AUTO_REUSE):
            cat_column = tf.compat.v1.feature_column \
                .categorical_column_with_identity('sequences', self.extended_vocab_size)
            embed_size = int(self.extended_vocab_size**0.25)
            embedding_column = tf.compat.v1.feature_column.embedding_column(
                cat_column, embed_size)
            seq_dict = {'sequences': sequences}
            embeddings = tf.compat.v1.feature_column \
                .input_layer(seq_dict, [embedding_column])
            sequence_lengths = tf.compat.v1.placeholder(
                "int64", shape=(None,),
                name=scope_name + "/input_layer/sequence_length")
            return embeddings, tf.cast(sequence_lengths, tf.int32)

    # Create the encoder for the model
    def encoder(self, encoder_inputs, is_training):
        input_embeddings, input_seq_lens = self.get_embeddings(
            encoder_inputs, 'encoder_emb')
        cell = self.stacked_lstm_cells(is_training, self.num_lstm_units)
        # The Bidirectional wrapper runs the backward pass itself
        rnn = tf.keras.layers.RNN(
            cell,
            return_sequences=True,
            return_state=True,
            dtype=tf.float32)
        Bi_rnn = tf.keras.layers.Bidirectional(rnn, merge_mode='concat')
        # Restore the (batch_size, max_seq_len, embed_dim) shape;
        # only one dimension can be left for tf.reshape to infer
        embed_dim = int(self.extended_vocab_size**0.25)
        batch_size = tf.shape(encoder_inputs)[0]
        input_embeddings = tf.reshape(
            input_embeddings, [batch_size, -1, embed_dim])
        outputs = Bi_rnn(input_embeddings)
        states_fw = [outputs[i] for i in range(1, self.num_lstm_layers + 1)]
        states_bw = [outputs[i]
                     for i in range(self.num_lstm_layers + 1, len(outputs))]
        for i in range(self.num_lstm_layers):
            bi_state_c, bi_state_h = get_bi_state_parts(
                states_fw[i], states_bw[i])