tensorflow - 为什么BasicRNNCell的矩阵形状不等于(n_hidden x n_input)?

时间:2016-11-12 00:38:05

标签: python-3.x tensorflow

我刚刚开始学习tensorflow,我试图构建一个简单的rnn。以下是重现我遇到的问题所需的所有代码。

tf.reset_default_graph()
rnn = tf.nn.rnn_cell.BasicRNNCell(110, 
                                  activation=tf.sigmoid)
x = tf.placeholder(tf.float32, shape=[20, 5, 2], name='x')
xt = tf.transpose(x)
x_split = [x_temp[:,0,:] for x_temp in tf.split(1, 5, xt)[::-1]]
h_list, _ = tf.nn.rnn(rnn, x_split, dtype=tf.float32)

tf.all_variables()[0].get_shape()
# TensorShape([Dimension(130), Dimension(110)])
x_split
# [<tf.Tensor 'Squeeze:0' shape=(2, 20) dtype=float32>,
#  <tf.Tensor 'Squeeze_1:0' shape=(2, 20) dtype=float32>,
#  <tf.Tensor 'Squeeze_2:0' shape=(2, 20) dtype=float32>,
#  <tf.Tensor 'Squeeze_3:0' shape=(2, 20) dtype=float32>,
#  <tf.Tensor 'Squeeze_4:0' shape=(2, 20) dtype=float32>]

为什么是矩阵的维数?我希望它是20 x 110,因为输入的维度为20。

tf.__version__
# 0.10.0rc0

1 个答案:

答案 0 :(得分:1)

BasicRNNCell具有以下机制(根据comment):

"""Most basic RNN: output = new_state = activation(W * input + U * state + B)."""

您检查大小的变量RNN/BasicRNNCell/Linear/Matrix:0是一个内部RNNCell变量,编码从一个状态到状态的转换。因此,它接受大小为20的输入和大小为110的先前状态,并输出大小为110的下一个状态,因此它被编码为130 x 110矩阵。

换句话说,它会从评论中连接UW