What is input_size in TensorFlow's BasicRNNCell?

Time: 2017-10-31 11:30:25

Tags: machine-learning tensorflow neural-network deep-learning recurrent-neural-network

According to the BasicRNNCell documentation:

__call__(
    inputs,
    state,
    scope=None)

Args:
inputs: 2-D tensor with shape [batch_size x input_size].

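This seems to suggest that input_size can differ from run to run. As far as I know, in an RNN, input_size determines the shape of the internal weight matrix W_x, which is (input_size, hidden_state_size), so it should stay constant. What happens if I run this cell alternately with input_size=3 and input_size=4?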
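The concern about W_x can be made concrete with a minimal pure-Python sketch. `ToyRNNCell` is a hypothetical class, not TensorFlow's implementation. It assumes the cell concatenates `inputs` and `state` and multiplies them by a single kernel of shape `(input_size + num_units, num_units)` (equivalent to stacking W_x on top of the recurrent weights), and that this kernel is created on the first call and reused afterward, as TF 1.x layers do:

```python
import math
import random


class ToyRNNCell:
    """Hypothetical pure-Python stand-in for BasicRNNCell (not TensorFlow code).

    Assumption: like TF 1.x layers, it creates its weights lazily on the first
    call, sizing them from the last dimension of `inputs`, and reuses them after.
    """

    def __init__(self, num_units, seed=0):
        self.num_units = num_units
        self.rng = random.Random(seed)
        self.input_size = None  # unknown until the first call
        self.kernel = None      # (input_size + num_units) x num_units
        self.bias = None        # num_units

    def zero_state(self, batch_size):
        return [[0.0] * self.num_units for _ in range(batch_size)]

    def _build(self, input_size):
        rows = input_size + self.num_units
        self.kernel = [[self.rng.uniform(-0.1, 0.1) for _ in range(self.num_units)]
                       for _ in range(rows)]
        self.bias = [0.0] * self.num_units
        self.input_size = input_size

    def __call__(self, inputs, state):
        if len(inputs) != len(state):
            raise ValueError(f"batch size mismatch: {len(inputs)} vs. {len(state)}")
        input_size = len(inputs[0])
        if self.kernel is None:
            self._build(input_size)  # input_size is frozen from here on
        elif input_size != self.input_size:
            raise ValueError(f"input size {input_size}, but weights were built for {self.input_size}")
        new_state = []
        for x_row, h_row in zip(inputs, state):
            xh = x_row + h_row  # list concatenation = [x, h], length input_size + num_units
            new_state.append([
                math.tanh(sum(v * self.kernel[i][j] for i, v in enumerate(xh)) + self.bias[j])
                for j in range(self.num_units)
            ])
        return new_state, new_state  # output == new state for a basic RNN cell


cell = ToyRNNCell(num_units=5)
out, _ = cell([[1.0, 2.0, 3.0]] * 4, cell.zero_state(4))  # first call: input_size 3
print(len(out), len(out[0]), len(cell.kernel))  # 4 5 8
try:
    cell([[1.0, 2.0, 3.0, 4.0]] * 4, cell.zero_state(4))  # later call: input_size 4
except ValueError as err:
    print("error:", err)  # error: input size 4, but weights were built for 3
```

Under these assumptions, alternating input_size=3 and input_size=4 fails on whichever size comes second, because the kernel's row count already encodes the first one.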

1 answer:

Answer 0 (score: 0)

inputs is a 2-D tensor of shape [batch_size x input_size].

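You are right that input_size has to stay the same: the cell creates its weights on the first call, sizing them from the last dimension of inputs, and reuses them afterward, so a later call with a different input_size fails. (In the example below input_size equals the cell's num_units, but nothing requires the two to be equal.) batch_size, on the other hand, can vary; it only has to match the batch size of the other argument of the call, state.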
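The same point in shape terms, assuming the standard basic RNN update in which the output equals the new state. The symbols are introduced here for illustration: $B$ = batch_size, $d$ = input_size, $n$ = num_units, $X$ = inputs, $H$ = state, $H'$ = new state.

$$
H' = \tanh\left(X W_x + H W_h + \mathbf{1}_B\, b^{\top}\right)
   = \tanh\left([X \;\; H]\, W + \mathbf{1}_B\, b^{\top}\right),
\qquad
W = \begin{bmatrix} W_x \\ W_h \end{bmatrix}
$$

with $X \in \mathbb{R}^{B\times d}$, $H, H' \in \mathbb{R}^{B\times n}$, $W_x \in \mathbb{R}^{d\times n}$, $W_h \in \mathbb{R}^{n\times n}$, $W \in \mathbb{R}^{(d+n)\times n}$ and $b \in \mathbb{R}^{n}$.

$B$ appears only in $X$, $H$ and $H'$, never in $W_x$, $W_h$ or $b$, so it can change between calls as long as $X$ and $H$ agree on it. $d$ is baked into the shape of $W_x$, so it cannot change once the weights exist.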

Try this code:

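import tensorflow as tf  # TensorFlow 1.x (uses placeholders and tf.contrib)
from tensorflow.contrib.rnn import BasicRNNCell

dim = 10
x = tf.placeholder(tf.float32, shape=[None, dim])      # unknown batch size, input_size 10
y = tf.placeholder(tf.float32, shape=[4, dim])         # batch size 4, input_size 10
z = tf.placeholder(tf.float32, shape=[None, dim + 1])  # unknown batch size, input_size 11
print('x, y, z:', x.shape, y.shape, z.shape)

cell = BasicRNNCell(dim)  # num_units = 10
state1 = cell.zero_state(batch_size=4, dtype=tf.float32)  # shape (4, 10)
state2 = cell.zero_state(batch_size=8, dtype=tf.float32)  # shape (8, 10)

# The first call creates the cell's weights for input_size 10.
# out1 is the output and out2 the new state; for BasicRNNCell they are the same.
out1, out2 = cell(x, state1)
print(out1.shape, out2.shape)

out1, out2 = cell(x, state2)
print(out1.shape, out2.shape)

out1, out2 = cell(y, state1)
print(out1.shape, out2.shape)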

Here is the output:

x, y, z: (?, 10) (4, 10) (?, 11)
(4, 10) (4, 10)
(8, 10) (8, 10)
(4, 10) (4, 10)

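The cell accepts x with either state and y with state1 (both have batch size 4), but it does not accept z with any state. Each of the following calls raises an error: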

out1, out2 = cell(y, state2)     # ERROR: batch sizes differ (4 vs. 8)
print(out1.shape, out2.shape)

out1, out2 = cell(z, state1)     # ERROR: input_size is 11, but the weights were built for 10
print(out1.shape, out2.shape)
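Both errors follow from two rules, which can be condensed into a static shape check. `predict_call` is a hypothetical helper, not part of TensorFlow. It models only those two rules (input_size fixed at 10 by the first call; inputs and state must agree on batch size, with `None` standing for the unknown `?` dimension) and does not reproduce TensorFlow's actual error messages:

```python
def predict_call(input_shape, state_shape, input_size):
    """Hypothetical static shape check for cell(inputs, state).

    Shapes are (batch, features) tuples; None means an unknown batch size.
    Returns the output shape, or a string describing the mismatch.
    """
    in_batch, in_features = input_shape
    st_batch, num_units = state_shape
    if in_features != input_size:
        return "error: input size mismatch"
    if in_batch is not None and st_batch is not None and in_batch != st_batch:
        return "error: batch size mismatch"
    # An unknown batch size takes the known one from the other argument.
    batch = in_batch if in_batch is not None else st_batch
    return (batch, num_units)


inputs = {"x": (None, 10), "y": (4, 10), "z": (None, 11)}
states = {"state1": (4, 10), "state2": (8, 10)}
for in_name, in_shape in inputs.items():
    for st_name, st_shape in states.items():
        print(in_name, st_name, predict_call(in_shape, st_shape, input_size=10))
# x state1 (4, 10)
# x state2 (8, 10)
# y state1 (4, 10)
# y state2 error: batch size mismatch
# z state1 error: input size mismatch
# z state2 error: input size mismatch
```

This matches the results above: x works with both states, y only with state1, and z with neither.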