Final States

Chapter Goals:

  • Learn about the final states for an LSTM and BiLSTM

A. The encoder

In an encoder-decoder model for seq2seq tasks, there are two components: the encoder and the decoder. The encoder is responsible for extracting useful information from the input sequence. For NLP applications, the encoder is normally an LSTM or BiLSTM.
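To see the wiring at a glance before digging into final states, here is a minimal sketch of an encoder LSTM handing its final state to a decoder LSTM. It uses tf.keras.layers.LSTM with random dummy tensors; all sizes and variable names are illustrative assumptions, not this chapter's model.

import tensorflow as tf

# Toy embedded sequences (batch_size, max_seq_len, embed_dim);
# all sizes here are illustrative
source_embeddings = tf.random.normal((2, 7, 4))
target_embeddings = tf.random.normal((2, 5, 4))

# Encoder: return_state=True exposes the final state tensors
encoder = tf.keras.layers.LSTM(5, return_state=True)
_, enc_h, enc_c = encoder(source_embeddings)

# Decoder: a regular forward-only LSTM seeded with the
# encoder's final state
decoder = tf.keras.layers.LSTM(5, return_sequences=True)
decoder_outputs = decoder(target_embeddings,
                          initial_state=[enc_h, enc_c])
print(decoder_outputs.shape)  # (2, 5, 5)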

B. LSTM final state

When using an LSTM or BiLSTM encoder, we need to pass the final state of the encoder into the decoder. The final state of an LSTM in TensorFlow is represented by the LSTMStateTuple object.

import tensorflow as tf
# Placeholders require graph mode
tf.compat.v1.disable_eager_execution()

# Input sequences (embedded)
# Shape: (batch_size, max_seq_len, embed_dim)
input_embeddings = tf.compat.v1.placeholder(
    tf.float32, shape=(None, None, 4))

cell = tf.keras.layers.LSTMCell(5)
rnn = tf.keras.layers.RNN(
    cell,
    return_state=True)

# With return_state=True, rnn returns the layer output followed
# by the two final state tensors
output = rnn(input_embeddings)

# Collect the final states in "final_state"
final_state = (output[1], output[2])

# final_state is the output of our LSTM encoder.
# It contains all the information about our input sequence,
# which in this case is just a tf.compat.v1.placeholder object
print(final_state)

An LSTMStateTuple object contains two important properties: the hidden state (c) and the state output (h). The hidden state represents the internal cell state (i.e. the "memory") of the LSTM cell, while the state output is the value the cell outputs at the final time step.

These two properties are represented by tensors with shape (batch_size, hidden_units).
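To verify these shapes concretely, we can run a small LSTM eagerly on a dummy batch; the batch size, sequence length, and unit count below are illustrative assumptions.

import tensorflow as tf

# Dummy embedded batch: batch_size=3, max_seq_len=6, embed_dim=4
dummy_embeddings = tf.random.normal((3, 6, 4))

cell = tf.keras.layers.LSTMCell(5)
rnn = tf.keras.layers.RNN(cell, return_state=True)
output, state_1, state_2 = rnn(dummy_embeddings)

# Each final-state tensor has shape (batch_size, hidden_units)
print(state_1.shape)  # (3, 5)
print(state_2.shape)  # (3, 5)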

C. Multi-layer final states

For a multi-layer LSTM, the final state output of tf.keras.layers.RNN is a tuple containing the final state for each layer.

import tensorflow as tf
# Placeholders require graph mode
tf.compat.v1.disable_eager_execution()

# Input sequences (embedded)
# Shape: (batch_size, max_seq_len, embed_dim)
input_embeddings = tf.compat.v1.placeholder(
    tf.float32, shape=(None, None, 4))

cell1 = tf.keras.layers.LSTMCell(5)
cell2 = tf.keras.layers.LSTMCell(8)
# Use a list (not a set) so the layer order is preserved
cells = [cell1, cell2]
multi_cell = tf.keras.layers.StackedRNNCells(cells)
rnn = tf.keras.layers.RNN(
    multi_cell,
    return_state=True,
    dtype=tf.float32)

outputs = rnn(input_embeddings)
final_state_cell1 = outputs[1]  # final state of layer 1
final_state_cell2 = outputs[2]  # final state of layer 2
print(final_state_cell1)
print(final_state_cell2)

D. BiLSTM final state

The final state of a BiLSTM consists of the final state for the forward LSTM and the final state for the backward LSTM. This corresponds to a tuple containing two LSTMStateTuple objects.

import tensorflow as tf
# Placeholders require graph mode
tf.compat.v1.disable_eager_execution()

# Input sequences (embedded)
# Shape: (batch_size, max_seq_len, embed_dim)
input_embeddings = tf.compat.v1.placeholder(
    tf.float32, shape=(None, None, 4))

cell = tf.keras.layers.LSTMCell(5)
rnn = tf.keras.layers.RNN(
    cell,
    return_state=True,
    dtype=tf.float32)
# Bidirectional creates the backward copy of the layer itself,
# so we don't set go_backwards on the wrapped RNN
bi_rnn = tf.keras.layers.Bidirectional(
    rnn,
    merge_mode='concat')

outputs = bi_rnn(input_embeddings)
forward_final_state = (outputs[1], outputs[2])
backward_final_state = (outputs[3], outputs[4])
print(forward_final_state)
print(backward_final_state)

When the BiLSTM has multiple layers, the final-state output of tf.keras.layers.Bidirectional is a tuple containing the final forward and backward states for each layer.

import tensorflow as tf
# Placeholders require graph mode
tf.compat.v1.disable_eager_execution()

# Input sequences (embedded)
# Shape: (batch_size, max_seq_len, embed_dim)
input_embeddings = tf.compat.v1.placeholder(
    tf.float32, shape=(None, None, 4))

# Multiple cells (a list, not a set, so layer order is preserved)
cell1 = tf.keras.layers.LSTMCell(5)
cell2 = tf.keras.layers.LSTMCell(9)
cells = [cell1, cell2]
# Stacking the cells
multi_cell = tf.keras.layers.StackedRNNCells(cells)
# Creating the RNN object
rnn = tf.keras.layers.RNN(
    multi_cell,
    return_state=True,
    dtype=tf.float32)
bi_rnn = tf.keras.layers.Bidirectional(
    rnn,
    merge_mode='concat')

outputs = bi_rnn(input_embeddings)
fw_final_state_cell1 = outputs[1]  # forward, layer 1
fw_final_state_cell2 = outputs[2]  # forward, layer 2
bw_final_state_cell1 = outputs[3]  # backward, layer 1
bw_final_state_cell2 = outputs[4]  # backward, layer 2
print(fw_final_state_cell1)
print(fw_final_state_cell2)
print(bw_final_state_cell1)
print(bw_final_state_cell2)

E. Combining forward and backward

In order to use BiLSTM final states in an encoder-decoder model, we need to combine the forward and backward states. This is because the decoder portion utilizes a regular LSTM, which only has a forward direction.

Luckily, combining the forward and backward states is rather simple. The main thing we need to do is concatenate the hidden state and state output of both the forward and backward states.

import tensorflow as tf

# forward_final_state and backward_final_state come from the
# single-layer BiLSTM example above.
# Hidden state (c) and state output (h) of the forward LSTM
fw_c = forward_final_state[0]
fw_h = forward_final_state[1]
# Hidden state (c) and state output (h) of the backward LSTM
bw_c = backward_final_state[0]
bw_h = backward_final_state[1]
# Concatenate along the final axis
final_c = tf.concat([fw_c, bw_c], -1)
final_h = tf.concat([fw_h, bw_h], -1)
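
Since the forward and backward LSTMs here use the same number of hidden units, final_c and final_h each end up with shape (batch_size, 2 * hidden_units), so the decoder cell that receives them must be sized accordingly.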

We'll describe in the next chapter how to convert the combined values into an LSTMStateTuple object.

Time to Code!

In this chapter you'll be completing the get_bi_state_parts function, which combines the forward and backward final states for a layer of the BiLSTM. The function is used as a helper in the Seq2SeqModel object's encoder function, which creates the encoder portion of the model.

When combining the forward and backward states, we use the tf.concat function to concatenate along the final axis. We perform separate concatenations for the hidden state (c) and the state output (h).

Set bi_state_c equal to tf.concat applied to a list containing state_fw.c (i.e. state_fw[0]) and state_bw.c (i.e. state_bw[0]), in that order, with -1 as the axis.

Set bi_state_h equal to tf.concat applied to a list containing state_fw.h (i.e. state_fw[1]) and state_bw.h (i.e. state_bw[1]), in that order, with -1 as the axis.

Return a tuple containing bi_state_c as the first element and bi_state_h as the second.

import tensorflow as tf

# Get c and h vectors for bidirectional LSTM final states
def get_bi_state_parts(state_fw, state_bw):
    # CODE HERE
    pass

# Seq2seq model
class Seq2SeqModel(object):
    def __init__(self, vocab_size, num_lstm_layers, num_lstm_units):
        self.vocab_size = vocab_size
        # Extended vocabulary includes start, stop token
        self.extended_vocab_size = vocab_size + 2
        self.num_lstm_layers = num_lstm_layers
        self.num_lstm_units = num_lstm_units
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=vocab_size)

    def make_lstm_cell(self, dropout_keep_prob, num_units):
        # Keras expects a drop rate, so convert from a keep probability
        cell = tf.keras.layers.LSTMCell(
            num_units, dropout=1.0 - dropout_keep_prob)
        return cell

    # Create multi-layer LSTM
    def stacked_lstm_cells(self, is_training, num_units):
        dropout_keep_prob = 0.5 if is_training else 1.0
        cell_list = [
            self.make_lstm_cell(dropout_keep_prob, num_units)
            for i in range(self.num_lstm_layers)]
        cell = tf.keras.layers.StackedRNNCells(cell_list)
        return cell

    # Get embeddings for input/output sequences
    def get_embeddings(self, sequences, scope_name):
        with tf.compat.v1.variable_scope(
                scope_name, reuse=tf.compat.v1.AUTO_REUSE):
            cat_column = tf.compat.v1.feature_column \
                .categorical_column_with_identity(
                    'sequences', self.extended_vocab_size)
            embed_size = int(self.extended_vocab_size**0.25)
            embedding_column = tf.compat.v1.feature_column.embedding_column(
                cat_column, embed_size)
            seq_dict = {'sequences': sequences}
            embeddings = tf.compat.v1.feature_column \
                .input_layer(seq_dict, [embedding_column])
            sequence_lengths = tf.compat.v1.placeholder(
                'int64', shape=(None,),
                name=scope_name + "/sinput_layer/sequence_length")
            return embeddings, tf.cast(sequence_lengths, tf.int32)

    # Create the encoder for the model
    def encoder(self, encoder_inputs, is_training):
        input_embeddings, input_seq_lens = self.get_embeddings(
            encoder_inputs, 'encoder_emb')
        cell = self.stacked_lstm_cells(is_training, self.num_lstm_units)
        rnn = tf.keras.layers.RNN(
            cell,
            return_sequences=True,
            return_state=True,
            dtype=tf.float32)
        bi_rnn = tf.keras.layers.Bidirectional(
            rnn,
            merge_mode='concat')
        # Restore the (batch_size, max_seq_len, embed_dim) shape;
        # tf.reshape allows only one dimension to be -1
        input_embeddings = tf.reshape(
            input_embeddings,
            [tf.shape(encoder_inputs)[0], -1, 2])
        outputs = bi_rnn(input_embeddings)
        states_fw = [
            outputs[i] for i in range(1, self.num_lstm_layers + 1)]
        states_bw = [
            outputs[i]
            for i in range(self.num_lstm_layers + 1, len(outputs))]
        for i in range(self.num_lstm_layers):
            bi_state_c, bi_state_h = get_bi_state_parts(
                states_fw[i], states_bw[i])
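
For reference, here is one possible completion of the helper, directly following the three steps above:

# Get c and h vectors for bidirectional LSTM final states
def get_bi_state_parts(state_fw, state_bw):
    # Concatenate the hidden states (c), then the state outputs (h),
    # of the forward and backward final states along the final axis
    bi_state_c = tf.concat([state_fw[0], state_bw[0]], -1)
    bi_state_h = tf.concat([state_fw[1], state_bw[1]], -1)
    return bi_state_c, bi_state_h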
