ValueError: Tensor("rnn/Const:0", shape=(1,), dtype=int32) must be from the same graph as Tensor("Equal_1:0", shape=(1,), dtype=bool)

Date: 2019-05-07 17:06:58

Tags: python tensorflow

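import tensorflow as tf  # TensorFlow 1.x (tf.contrib does not exist in 2.x)

# n_hidden, n_steps and length() are not shown in the question;
# they are assumed to be defined elsewhere in the notebook.
def text_rnn(input_text, batch_size=64, reuse=None):
    # GRU cell with small truncated-normal initializers for weights and biases
    cell = tf.contrib.rnn.GRUCell(
        n_hidden,
        kernel_initializer=tf.truncated_normal_initializer(stddev=0.0001),
        bias_initializer=tf.truncated_normal_initializer(stddev=0.0001),
        reuse=reuse)
    # Unroll the RNN; sequence_length stops each sequence at its true length
    output, _ = tf.nn.dynamic_rnn(
        cell,
        input_text,
        dtype=tf.float32,
        sequence_length=length(input_text))

    # Flat row index of the last valid time step of each sequence
    index = tf.range(0, batch_size) * n_steps + (tf.cast(length(input_text), tf.int32) - 1)
    # Flatten output to [batch_size * n_steps, n_hidden] and pick those rows
    flat = tf.reshape(output, [-1, int(output.get_shape()[2])])
    last = tf.gather(flat, index)
    return last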

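The first call to this function works, but calling it again in the same session raises an error. Restarting the session makes it work again. The call is: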

last = text_rnn(input_text)

Error:

    ValueError                                Traceback (most recent call last)
    <ipython-input-45-4c61d5c2fe6b> in <module>
         13 losses = []
         14 step = 0
    ---> 15 last = text_rnn(input_text)
         16 g_loss, d_loss = get_loss(real_image_batch, wrong_image_batch, inputs_noise, last, image_depth, smooth=0.1)
         17 g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, beta1, learning_rate)

    <ipython-input-44-e190e8e83384> in text_rnn(input_text, batch_size, reuse)
          8                                   input_text,
          9                                   dtype=tf.float32,
    ---> 10                                   sequence_length = length(input_text)
         11                                   )
         12

    ValueError: Tensor("rnn_1/Const:0", shape=(1,), dtype=int32) must be from the same graph as Tensor("Equal_2:0", shape=(1,), dtype=bool).
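A minimal sketch of one way this exact error can arise, assuming (the question does not show it) that the default graph is reset or replaced between the two calls while `input_text` is kept from the earlier run. In TF1-style graph mode an op is built in the graph its input tensors belong to, so a placeholder left over from an old graph cannot be combined with a tensor created in the new default graph. That would fit the traceback: `rnn_1/Const` is created inside `dynamic_rnn` in the current default graph, while `Equal_2` presumably comes from `length(input_text)`. The sketch uses the `tf.compat.v1` API; the shapes, the reset, and the helper `build_on` are hypothetical.

```python
import tensorflow as tf

tf1 = tf.compat.v1              # TF1-style graph API (late TF 1.x and TF 2.x)
tf1.disable_eager_execution()   # build graphs instead of running eagerly

# Hypothetical placeholder, shape [batch, n_steps, n_features]
input_text = tf1.placeholder(tf.float32, [None, 5, 3], name="input_text")
old_graph = input_text.graph

# Assumed cause: a later cell resets the default graph,
# but input_text still belongs to the previous graph.
tf1.reset_default_graph()
new_graph = tf1.get_default_graph()

def build_on(x):
    """Hypothetical helper: combine x with a constant from the current default graph."""
    c = tf.constant(1.0)  # created in the current default graph
    return x + c          # fails if x belongs to a different graph

try:
    build_on(input_text)
    error = None
except ValueError as e:
    error = str(e)        # "... must be from the same graph as ..."

print(old_graph is new_graph)  # False
print(error)

# Recreating the placeholder after the reset keeps every tensor in one graph.
input_text = tf1.placeholder(tf.float32, [None, 5, 3], name="input_text")
ok = build_on(input_text)
print(ok.graph is tf1.get_default_graph())  # True
```

Under that assumption, the fix would be to recreate `input_text` (and any other placeholders) after the graph is reset, or to build the whole model inside one `with tf.Graph().as_default():` block, so that every tensor passed to `text_rnn` comes from the same graph.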
