Combined State
Combine the final states for a BiLSTM into usable initial states.
Chapter Goals:
- Combine the final states for each BiLSTM layer
A. LSTMStateTuple initialization
We initialize an LSTMStateTuple object with a cell state (c) and a hidden state output (h).
Below, we show an example of initializing an LSTMStateTuple object using the final forward and backward states from a single-layer BiLSTM encoder.
import tensorflow as tf

# Forward state of single-layer BiLSTM final states
fw_c = forward_final_state[0]
fw_h = forward_final_state[1]

# Backward state of single-layer BiLSTM final states
bw_c = backward_final_state[0]
bw_h = backward_final_state[1]

# Concatenate along final axis
final_c = tf.concat([fw_c, bw_c], -1)
final_h = tf.concat([fw_h, bw_h], -1)

combined_state = tf.compat.v1.nn.rnn_cell.LSTMStateTuple(final_c, final_h)
print(combined_state)
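Since the forward and backward states are concatenated along the last axis, the c and h tensors in the printed LSTMStateTuple each have twice as many units as either direction alone.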
In the above example, we combined the BiLSTM forward and backward states into a single LSTMStateTuple object, which can be passed into the decoder.
For a multi-layer BiLSTM, the combined final state is a tuple of LSTMStateTuple objects. The element at each index of the tuple is the corresponding layer's combined final state.
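As an illustration, here is a minimal, self-contained sketch of building that tuple for a multi-layer encoder. The zero-valued per-layer states are stand-ins for the real BiLSTM final states, which the encoder produces:

import tensorflow as tf

num_layers, batch_size, num_units = 2, 4, 8

# Stand-in (c, h) final states for each layer and direction;
# in the real encoder these come from the BiLSTM's outputs
forward_final_states = [(tf.zeros([batch_size, num_units]),
                         tf.zeros([batch_size, num_units]))
                        for _ in range(num_layers)]
backward_final_states = [(tf.zeros([batch_size, num_units]),
                          tf.zeros([batch_size, num_units]))
                         for _ in range(num_layers)]

combined_state = []
for state_fw, state_bw in zip(forward_final_states, backward_final_states):
    layer_c = tf.concat([state_fw[0], state_bw[0]], -1)  # combined cell state
    layer_h = tf.concat([state_fw[1], state_bw[1]], -1)  # combined hidden state
    combined_state.append(
        tf.compat.v1.nn.rnn_cell.LSTMStateTuple(layer_c, layer_h))

# One LSTMStateTuple per layer, wrapped in a tuple
final_state = tuple(combined_state)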
Time to Code!
In this chapter, you'll be finishing the for loop of the encoder function. This is on line 65 of the code editor.
For each BiLSTM layer, we create a combined state using the combined c and h properties from the previous chapter.
Inside the for loop, set bi_lstm_state equal to tf.compat.v1.nn.rnn_cell.LSTMStateTuple, initialized with bi_state_c and bi_state_h.
After creating the layer's final state, we append it to the end of combined_state.
Inside the for loop, append bi_lstm_state to the end of combined_state.
import tensorflow as tf

# Get c and h vectors for bidirectional LSTM final states
def ref_get_bi_state_parts(state_fw, state_bw):
    bi_state_c = tf.concat([state_fw[0], state_bw[0]], -1)
    bi_state_h = tf.concat([state_fw[1], state_bw[1]], -1)
    return bi_state_c, bi_state_h

# Seq2seq model
class Seq2SeqModel(object):
    def __init__(self, vocab_size, num_lstm_layers, num_lstm_units):
        self.vocab_size = vocab_size
        # Extended vocabulary includes start, stop token
        self.extended_vocab_size = vocab_size + 2
        self.num_lstm_layers = num_lstm_layers
        self.num_lstm_units = num_lstm_units
        self.tokenizer = tf.keras.preprocessing.text.Tokenizer(
            num_words=vocab_size)

    def make_lstm_cell(self, dropout_keep_prob, num_units):
        # Keras expects a drop rate, so convert from a keep probability
        cell = tf.keras.layers.LSTMCell(
            num_units, dropout=1.0 - dropout_keep_prob)
        return cell

    # Create multi-layer LSTM
    def stacked_lstm_cells(self, is_training, num_units):
        dropout_keep_prob = 0.5 if is_training else 1.0
        cell_list = [self.make_lstm_cell(dropout_keep_prob, num_units)
                     for _ in range(self.num_lstm_layers)]
        cell = tf.keras.layers.StackedRNNCells(cell_list)
        return cell

    # Get embeddings for input/output sequences
    def get_embeddings(self, sequences, scope_name):
        with tf.compat.v1.variable_scope(
                scope_name, reuse=tf.compat.v1.AUTO_REUSE):
            cat_column = tf.compat.v1.feature_column \
                .categorical_column_with_identity(
                    'sequences', self.extended_vocab_size)
            embed_size = int(self.extended_vocab_size**0.25)
            embedding_column = tf.compat.v1.feature_column.embedding_column(
                cat_column, embed_size)
            seq_dict = {'sequences': sequences}
            embeddings = tf.compat.v1.feature_column \
                .input_layer(seq_dict, [embedding_column])
            sequence_lengths = tf.compat.v1.placeholder(
                "int64", shape=(None,),
                name=scope_name + "/sinput_layer/sequence_length")
            return embeddings, tf.cast(sequence_lengths, tf.int32)

    # Create the encoder for the model
    def encoder(self, encoder_inputs, is_training):
        input_embeddings, input_seq_lens = self.get_embeddings(
            encoder_inputs, 'encoder_emb')
        cell = self.stacked_lstm_cells(is_training, self.num_lstm_units)
        combined_state = []
        rnn = tf.keras.layers.RNN(
            cell,
            return_sequences=True,
            return_state=True,
            go_backwards=True,
            dtype=tf.float32)
        Bi_rnn = tf.keras.layers.Bidirectional(rnn, merge_mode='concat')
        # Reshape the flat embeddings into a 3-D [batch, time, features]
        # tensor (tf.reshape allows at most one -1 dimension)
        input_embeddings = tf.reshape(
            input_embeddings, [tf.shape(encoder_inputs)[0], -1, 2])
        outputs = Bi_rnn(input_embeddings)
        enc_outputs = outputs[0]
        # Per-layer final states for the forward and backward directions
        states_fw = [outputs[i]
                     for i in range(1, self.num_lstm_layers + 1)]
        states_bw = [outputs[i]
                     for i in range(self.num_lstm_layers + 1, len(outputs))]
        for i in range(self.num_lstm_layers):
            bi_state_c, bi_state_h = ref_get_bi_state_parts(
                states_fw[i], states_bw[i])
            # CODE HERE
        final_state = tuple(combined_state)
        return enc_outputs, input_seq_lens, final_state
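For reference, instantiating the model requires only the three constructor arguments; the values below are arbitrary placeholders:

model = Seq2SeqModel(vocab_size=500, num_lstm_layers=2, num_lstm_units=64)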