Tensorflow:确定MultiRNN中的最终状态

时间:2018-08-10 01:11:36

标签: python tensorflow rnn

我是TF的新手,我正在尝试在NN中实现多个GRU单元。但是,我无法确定MultiRNN单元的最终状态。

例如,当我使用以下代码时:

num_units = [128, 128]
tf.reset_default_graph()
x = tf.placeholder(tf.int32, [None, 134])
y = tf.placeholder(tf.int32, [None]) 
embedding_matrix = tf.Variable(tf.random_uniform([153, 128], -1.0, 1.0))
embeddings = tf.nn.embedding_lookup(embedding_matrix, x) 
cells = [tf.contrib.rnn.GRUCell(num_units=n) for n in num_units]
cell_type = tf.contrib.rnn.MultiRNNCell(cells=cells, state_is_tuple=True)
cell_type = tf.contrib.rnn.DropoutWrapper(cell=cell_type, output_keep_prob=0.75)
_, (encoding, _) = tf.nn.dynamic_rnn(cell_type, embeddings, dtype=tf.float32)

最后一行代码的输出是:

(<tf.Tensor 'rnn/transpose_1:0' shape=(?, 134, 128) dtype=float32>, (<tf.Tensor 'rnn/while/Exit_3:0' shape=(?, 128) dtype=float32>, <tf.Tensor 'rnn/while/Exit_4:0' shape=(?, 128) dtype=float32>))

我相信格式是:

  

输出格式:(a,[b,c])

文档说输出采用以下格式(输出,状态= [batch_size,cell.state_size])。但是,我无法确定其中哪一个是此存储单元的最终状态。我认为应该是b。

此外,当我在4个GRU单元上运行相同的代码时:

num_units = [128, 128, 128, 128]

输出结果更加令人困惑:

  

输出格式 :( a,[b,c,d,e])

我对以上哪一个是最终的存储状态感到困惑,然后可以进一步处理该状态以进行损耗计算和做出预测。

1 个答案:

答案 0 :(得分:0)

好的。我从here找到了答案,它是MultiRNN堆栈中最后一个单元的状态。以下代码将根据输入到num_units中的尺寸轻松提取适当的代码:

num_units = [128, 128, 128, 128]
rnn_output, final_states = tf.nn.dynamic_rnn(cell_type, embeddings, dtype=tf.float32)
encoding = final_states[len(num_units)-1] # For GRU & RNN
encoding = final_states[len(num_units)-1][0] # For LSTM