def text_rnn(input_text, batch_size=64, reuse=None):
cell = tf.contrib.rnn.GRUCell(n_hidden,
kernel_initializer =
tf.truncated_normal_initializer(stddev=0.0001),
bias_initializer =
tf.truncated_normal_initializer(stddev=0.0001),
reuse=reuse)
output, _ = tf.nn.dynamic_rnn(
cell,
input_text,
dtype=tf.float32,
sequence_length = length(input_text)
)
index = tf.range(0,batch_size)*n_steps + (tf.cast(length(input_text),tf.int32) - 1)
flat = tf.reshape(output,[-1,int(output.get_shape()[2])])
last = tf.gather(flat,index)
return last
第一次调用该函数运行正常,但是下次在同一会话上运行会出错。重新启动会话可以正常工作。
last = text_rnn(input_text)
错误:
ValueError Traceback (most recent call last)
<ipython-input-45-4c61d5c2fe6b> in <module>
13 losses = []
14 step = 0
---> 15 last = text_rnn(input_text)
16 g_loss, d_loss = get_loss(real_image_batch, wrong_image_batch, inputs_noise, last, image_depth, smooth=0.1)
17 g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, beta1, learning_rate)
<ipython-input-44-e190e8e83384> in text_rnn(input_text, batch_size, reuse)
8 input_text,
9 dtype=tf.float32,
---> 10 sequence_length = length(input_text)
11 )
12
ValueError: Tensor("rnn_1/Const:0", shape=(1,), dtype=int32) must be from the same graph as Tensor("Equal_2:0", shape=(1,), dtype=bool).