Combined State

Combine the final states for a BiLSTM into usable initial states.

Chapter Goals:

  • Combine the final states for each BiLSTM layer

A. LSTMStateTuple initialization

We initialize an LSTMStateTuple object with a cell state (c) and a hidden state output (h).

Below, we show an example of initializing an LSTMStateTuple object using the final forward and backward states from a single-layer BiLSTM encoder.

import tensorflow as tf
# Forward state of single-layer BiLSTM final states
fw_c = forward_final_state[0]
fw_h = forward_final_state[1]
# Backward state of single-layer BiLSTM final states
bw_c = backward_final_state[0]
bw_h = backward_final_state[1]
# Concatenate along final axis
final_c = tf.concat([fw_c, bw_c], -1)
final_h = tf.concat([fw_h, bw_h], -1)
combined_state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(
    final_c, final_h)
print(combined_state)

In the above example, we combined the BiLSTM forward and backward states into a single LSTMStateTuple object, which can be passed into the decoder.
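For instance, here is a minimal sketch of feeding such a combined state into a decoder built with the TF1-compat dynamic_rnn API. The decoder cell, the dummy sizes, and the zero-valued stand-in tensors below are illustrative assumptions, not the lesson's actual decoder.

import tensorflow as tf

# Hypothetical sizes, for illustration only
batch_size, time_steps, embed_size, num_units = 4, 10, 8, 128

# Stand-ins for the encoder's combined final state; after concatenation,
# c and h are each 2 * num_units wide
final_c = tf.zeros([batch_size, 2 * num_units])
final_h = tf.zeros([batch_size, 2 * num_units])
combined_state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(final_c, final_h)

# The decoder cell's size must match the width of the combined state
decoder_cell = tf.compat.v1.nn.rnn_cell.LSTMCell(2 * num_units)

# Dummy decoder input embeddings; in a real model these would come from an
# embedding layer
decoder_embeddings = tf.zeros([batch_size, time_steps, embed_size])

dec_outputs, dec_final_state = tf.compat.v1.nn.dynamic_rnn(
    decoder_cell,
    decoder_embeddings,
    initial_state=combined_state,
    dtype=tf.float32)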

For BiLSTM encoders with multiple layers, we combine the states for each layer to create a tuple of LSTMStateTuple objects. The element at index i of the tuple is the i-th layer's combined final state.
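A rough sketch of this per-layer combination is shown below. The layer count, batch size, unit size, and the stand-in states_fw/states_bw lists are made-up placeholders, not values from the model code.

import tensorflow as tf

# Hypothetical sizes, for illustration only
num_layers, batch_size, num_units = 2, 4, 128

# Stand-ins for each layer's final forward/backward (c, h) states
states_fw = [(tf.zeros([batch_size, num_units]), tf.zeros([batch_size, num_units]))
             for _ in range(num_layers)]
states_bw = [(tf.zeros([batch_size, num_units]), tf.zeros([batch_size, num_units]))
             for _ in range(num_layers)]

combined_state = []
for layer_fw, layer_bw in zip(states_fw, states_bw):
    # Concatenate the forward and backward c and h along the final axis
    layer_c = tf.concat([layer_fw[0], layer_bw[0]], -1)
    layer_h = tf.concat([layer_fw[1], layer_bw[1]], -1)
    combined_state.append(
        tf.compat.v1.nn.rnn_cell.LSTMStateTuple(layer_c, layer_h))

# Element i of the tuple is layer i's combined final state
final_state = tuple(combined_state)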

Time to Code!

In this chapter you'll be finishing the for loop of the encoder function. This is on line 65 of the code editor.

For each BiLSTM layer, we create a combined state using the combined c and h properties from the previous chapter.

Inside the for loop, set bi_lstm_state equal to tf.compat.v1.nn.rnn_cell.LSTMStateTuple, initialized with bi_state_c and bi_state_h.

After creating the layer's final state, we append it to the end of combined_state.

Inside the for loop, append bi_lstm_state to the end of combined_state.

import tensorflow as tf

# Get c and h vectors for bidirectional LSTM final states
def ref_get_bi_state_parts(state_fw, state_bw):
    bi_state_c = tf.concat([state_fw[0], state_bw[0]], -1)
    bi_state_h = tf.concat([state_fw[1], state_bw[1]], -1)
    return bi_state_c, bi_state_h

# Seq2seq model
class Seq2SeqModel(object):
    def __init__(self, vocab_size, num_lstm_layers, num_lstm_units):
        self.vocab_size = vocab_size
        # Extended vocabulary includes start, stop token
        self.extended_vocab_size = vocab_size + 2
        self.num_lstm_layers = num_lstm_layers
        self.num_lstm_units = num_lstm_units
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=vocab_size)

    def make_lstm_cell(self, dropout_keep_prob, num_units):
        cell = tf.keras.layers.LSTMCell(num_units, dropout=dropout_keep_prob)
        return cell

    # Create multi-layer LSTM
    def stacked_lstm_cells(self, is_training, num_units):
        dropout_keep_prob = 0.5 if is_training else 1.0
        cell_list = [self.make_lstm_cell(dropout_keep_prob, num_units)
                     for i in range(self.num_lstm_layers)]
        cell = tf.keras.layers.StackedRNNCells(cell_list)
        return cell

    # Get embeddings for input/output sequences
    def get_embeddings(self, sequences, scope_name):
        with tf.compat.v1.variable_scope(scope_name, reuse=tf.compat.v1.AUTO_REUSE):
            cat_column = tf.compat.v1.feature_column \
                .categorical_column_with_identity(
                    'sequences', self.extended_vocab_size)
            embed_size = int(self.extended_vocab_size**0.25)
            embedding_column = tf.compat.v1.feature_column.embedding_column(
                cat_column, embed_size)
            seq_dict = {'sequences': sequences}
            embeddings = tf.compat.v1.feature_column \
                .input_layer(seq_dict, [embedding_column])
            sequence_lengths = tf.compat.v1.placeholder(
                "int64", shape=(None,),
                name=scope_name + "/sinput_layer/sequence_length")
            return embeddings, tf.cast(sequence_lengths, tf.int32)

    # Create the encoder for the model
    def encoder(self, encoder_inputs, is_training):
        input_embeddings, input_seq_lens = self.get_embeddings(
            encoder_inputs, 'encoder_emb')
        cell = self.stacked_lstm_cells(is_training, self.num_lstm_units)
        combined_state = []
        rnn = tf.keras.layers.RNN(
            cell,
            return_sequences=True,
            return_state=True,
            go_backwards=True,
            dtype=tf.float32)
        Bi_rnn = tf.keras.layers.Bidirectional(
            rnn,
            merge_mode='concat')
        input_embeddings = tf.reshape(input_embeddings, [-1, -1, 2])
        outputs = Bi_rnn(input_embeddings)
        enc_outputs = outputs[0]
        states_fw = [outputs[i] for i in range(1, self.num_lstm_layers + 1)]
        states_bw = [outputs[i] for i in range(self.num_lstm_layers + 1, len(outputs))]
        for i in range(self.num_lstm_layers):
            bi_state_c, bi_state_h = ref_get_bi_state_parts(
                states_fw[i], states_bw[i])
            #CODE HERE
        final_state = tuple(combined_state)
        return enc_outputs, input_seq_lens, final_state
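For reference, once you have followed the instructions above, the body of the for loop could look like the following sketch. It reuses only the names already defined in the encoder function.

for i in range(self.num_lstm_layers):
    bi_state_c, bi_state_h = ref_get_bi_state_parts(
        states_fw[i], states_bw[i])
    # Combine the layer's c and h parts into a single LSTMStateTuple
    bi_lstm_state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(
        bi_state_c, bi_state_h)
    # Append each layer's combined final state in order
    combined_state.append(bi_lstm_state)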
