我刚刚开始学习tensorflow,我试图构建一个简单的rnn。以下是重现我遇到的问题所需的所有代码。
tf.reset_default_graph()
rnn = tf.nn.rnn_cell.BasicRNNCell(110,
activation=tf.sigmoid)
x = tf.placeholder(tf.float32, shape=[20, 5, 2], name='x')
xt = tf.transpose(x)
x_split = [x_temp[:,0,:] for x_temp in tf.split(1, 5, xt)[::-1]]
h_list, _ = tf.nn.rnn(rnn, x_split, dtype=tf.float32)
tf.all_variables()[0].get_shape()
# TensorShape([Dimension(130), Dimension(110)])
x_split
# [<tf.Tensor 'Squeeze:0' shape=(2, 20) dtype=float32>,
# <tf.Tensor 'Squeeze_1:0' shape=(2, 20) dtype=float32>,
# <tf.Tensor 'Squeeze_2:0' shape=(2, 20) dtype=float32>,
# <tf.Tensor 'Squeeze_3:0' shape=(2, 20) dtype=float32>,
# <tf.Tensor 'Squeeze_4:0' shape=(2, 20) dtype=float32>]
为什么是矩阵的维数?我希望它是20 x 110,因为输入的维度为20。
tf.__version__
# 0.10.0rc0
答案 0 :(得分:1)
BasicRNNCell
具有以下机制(根据comment):
"""Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
您检查大小的变量RNN/BasicRNNCell/Linear/Matrix:0
是一个内部RNNCell变量,编码从一个状态到状态的转换。因此,它接受大小为20的输入和大小为110的先前状态,并输出大小为110的下一个状态,因此它被编码为130 x 110
矩阵。
换句话说,它会从评论中连接U
和W
。