根据BasicRNNCell
的文件:
__call__(
inputs,
state,
scope=None)
Args:
inputs: 2-D tensor with shape [batch_size x input_size].
input_size
似乎在不同的运行中有所不同?据我所知,RN input_size
确定内部权重矩阵W_x的形状为(input_size, hidden_state_size)
,它应该是一致的。如果我交替使用input_size=3
和input_size=4
运行此单元格会怎样?
答案 0 :(得分:0)
inputs
是二维张量:[batch_size x input_size]
。
您是对的,input_size
必须与RNN小区的num_units
对应。但batch_size
可能会有所不同,只需与调用的另一个参数state
相对应。
试试这段代码:
import tensorflow as tf
from tensorflow.contrib.rnn import BasicRNNCell
dim = 10
x = tf.placeholder(tf.float32, shape=[None, dim])
y = tf.placeholder(tf.float32, shape=[4, dim])
z = tf.placeholder(tf.float32, shape=[None, dim + 1])
print('x, y, z:', x.shape, y.shape, z.shape)
cell = BasicRNNCell(dim)
state1 = cell.zero_state(batch_size=4, dtype=tf.float32)
state2 = cell.zero_state(batch_size=8, dtype=tf.float32)
out1, out2 = cell(x, state1)
print(out1.shape, out2.shape)
out1, out2 = cell(x, state2)
print(out1.shape, out2.shape)
out1, out2 = cell(y, state1)
print(out1.shape, out2.shape)
这是输出:
x, y, z: (?, 10) (4, 10) (?, 11)
(4, 10) (4, 10)
(8, 10) (8, 10)
(4, 10) (4, 10)
此单元格接受两个状态x
y
state1
,并且不接受z
任何状态。以下两个调用都会导致错误:
out1, out2 = cell(y, state2) # ERROR: dimensions mismatch
print(out1.shape, out2.shape)
out1, out2 = cell(z, state1) # ERROR: dimensions mismatch
print(out1.shape, out2.shape)