I am building an RNN and use tf.nn.dynamic_rnn to produce the outputs and state. The code is as follows (TF version 1.3):
import tensorflow as tf

def lstm_cell():
    # One LSTM layer with 128 units and dropout applied to its outputs
    return tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(128), output_keep_prob=0.7)

# Stack of 3 LSTM layers
cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(3)])
initial_state = cell.zero_state(1, tf.float32)
# Input: batch size 1, one time step, 36 features
layer = tf.placeholder(tf.float32, [1, 1, 36])
outputs, state = tf.nn.dynamic_rnn(cell=cell, inputs=layer, initial_state=initial_state)
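Since the input tensor always has a batch size of 1, initial_state and state also have a batch size of 1. layer is likewise an input with batch_size = 1, and each input has 36 nodes (the embedding size). Every layer has an lstm_size of 128.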
The problem appears when I loop the RNN cell, feeding each step's outputs back in as the next input:
rnn_outputs_sequence = outputs
for i in range(1, num_pics, 1):
    outputs, state = tf.nn.dynamic_rnn(cell=cell, inputs=outputs, initial_state=state)
    rnn_outputs_sequence = tf.concat((rnn_outputs_sequence, outputs), axis=1)
I expected rnn_outputs_sequence to have the shape [1, num_pics, 36]. However, this raises the error:
Trying to share variable rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel, but specified shape (256, 512) and found shape (164, 512).
I can't figure out where the shape [164, 512] comes from. Can anyone help me solve this? Thanks.
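The two shapes in the error message can be reproduced with a little arithmetic. This is a minimal sketch, assuming that BasicLSTMCell builds one kernel per layer of shape (input_depth + num_units, 4 * num_units), because it concatenates the input with the previous hidden state and computes all four gates in one matrix multiply. The helper names lstm_kernel_shape and stacked_kernel_shapes are hypothetical, not TensorFlow functions:

```python
def lstm_kernel_shape(input_depth, num_units):
    """Kernel shape assumed for one BasicLSTMCell-style layer:
    rows = input features + previous hidden state, columns = 4 gates."""
    return (input_depth + num_units, 4 * num_units)


def stacked_kernel_shapes(input_depth, num_units, num_layers):
    """Kernel shapes for each layer of a MultiRNNCell-style stack.
    Every layer after the first receives the previous layer's
    num_units-dimensional output as its input."""
    shapes = []
    depth = input_depth
    for _ in range(num_layers):
        shapes.append(lstm_kernel_shape(depth, num_units))
        depth = num_units
    return shapes


# First dynamic_rnn call: input depth 36, so cell_0's kernel is (164, 512).
print(stacked_kernel_shapes(36, 128, 3))
# Loop calls: the input is the previous outputs (depth 128),
# so cell_0 would need a (256, 512) kernel instead.
print(stacked_kernel_shapes(128, 128, 3))
```

So (164, 512) is the kernel cell_0 created on the first call, when the input depth was 36. Inside the loop the input is the previous outputs, whose depth is 128 rather than 36, so the shared kernel would have to be (256, 512). For the same reason, rnn_outputs_sequence would end up with a last dimension of 128, not 36.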
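Answer: project the 128-dimensional outputs back to 36 features with a dense layer before feeding them into the next dynamic_rnn call (this version uses 2 layers):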
import tensorflow as tf

def lstm_cell():
    return tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(128), output_keep_prob=0.7)

cell = tf.contrib.rnn.MultiRNNCell([lstm_cell() for _ in range(2)])
initial_state = cell.zero_state(1, tf.float32)
layer = tf.placeholder(tf.float32, [1, 1, 36])

# First step: [1, 1, 36] input -> outputs of shape [1, 1, 128]
outputs, state = tf.nn.dynamic_rnn(cell=cell, inputs=layer, initial_state=initial_state)
# Project the 128 LSTM features back to 36 so the result can be fed in again
outputs = tf.reshape(outputs, shape=[1, -1])
outputs = tf.layers.dense(outputs, 36,
                          kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False))
outputs = tf.reshape(outputs, shape=[1, 1, -1])  # back to [1, 1, 36]
rnn_outputs_sequence = outputs
print(outputs)

for i in range(1, 16, 1):
    # The input depth is 36 again, matching the kernels created on the first call
    outputs, state = tf.nn.dynamic_rnn(cell=cell, inputs=outputs, initial_state=state)
    outputs = tf.reshape(outputs, shape=[1, -1])
    outputs = tf.layers.dense(outputs, 36,
                              kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False))
    outputs = tf.reshape(outputs, shape=[1, 1, -1])
    print(outputs)
    rnn_outputs_sequence = tf.concat((rnn_outputs_sequence, outputs), axis=1)

print(rnn_outputs_sequence)  # shape [1, 16, 36]
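The same feedback loop can be traced without TensorFlow to check the shapes. This is a simplified sketch, assuming a single LSTM layer (the answer uses 2), no dropout, and the i, j, f, o gate layout with a forget bias of 1.0 that BasicLSTMCell uses to my understanding. All names here (lstm_step, proj_w, and so on) are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

INPUT_DEPTH = 36   # embedding size from the question
NUM_UNITS = 128    # lstm_size from the question
NUM_STEPS = 16     # first step plus range(1, 16) in the answer

# One LSTM layer: a single kernel of shape (input_depth + units, 4 * units)
kernel = rng.normal(scale=0.1, size=(INPUT_DEPTH + NUM_UNITS, 4 * NUM_UNITS))
bias = np.zeros(4 * NUM_UNITS)
# Output projection 128 -> 36 (the role tf.layers.dense plays in the answer)
proj_w = rng.normal(scale=0.1, size=(NUM_UNITS, INPUT_DEPTH))
proj_b = np.zeros(INPUT_DEPTH)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, c, h):
    """One LSTM step; x: [batch, INPUT_DEPTH], c and h: [batch, NUM_UNITS]."""
    gates = np.concatenate([x, h], axis=1) @ kernel + bias  # [batch, 4 * NUM_UNITS]
    i, j, f, o = np.split(gates, 4, axis=1)
    new_c = c * sigmoid(f + 1.0) + sigmoid(i) * np.tanh(j)  # forget bias 1.0
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_c, new_h


x = rng.normal(size=(1, INPUT_DEPTH))  # the [1, 1, 36] input with the time axis dropped
c = np.zeros((1, NUM_UNITS))
h = np.zeros((1, NUM_UNITS))
steps = []
for _ in range(NUM_STEPS):
    c, h = lstm_step(x, c, h)
    x = h @ proj_w + proj_b  # project 128 -> 36 so it can be fed back in
    steps.append(x)

sequence = np.stack(steps, axis=1)  # [batch, NUM_STEPS, INPUT_DEPTH]
print(sequence.shape)  # (1, 16, 36)
```

One design difference: this sketch shares a single projection across all steps, whereas in the answer every tf.layers.dense call creates its own variables, so each step gets a separate projection. If shared weights are wanted, passing the same name together with reuse=True to tf.layers.dense should achieve that.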