首次运行后,读取Tensorflow记录文件不起作用

时间:2016-02-17 14:24:34

标签: tensorflow

我有一小段代码从一些TFRecord文件中读取数据。如果我从ipython笔记本运行代码,它在第一次执行块时工作正常。但是,如果我尝试再次执行它而不重新启动内核,则相同的代码块会产生错误(错误:StatusNotOK:未找到:FetchOutputs节点DecodeRaw_2:0:未找到)。代码如下所示。我是否需要关闭/清除/重新初始化某些内容才能让代码多次正常运行?

filename_queue = tf.train.string_input_producer(filename_list)
init = tf.initialize_all_variables()
image = []
label = []
with tf.Session() as sess:
    sess.run(init)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    tf_image, tf_label = read_and_decode(filename_queue)
    for i in range(len(filename_list)):
        image.append(sess.run(tf_image))
        label.append(sess.run(tf_label))

    coord.request_stop()
    coord.join(threads)

请注意,read_and_decode()取自here

2 个答案:

答案 0 :(得分:2)

问题中的代码存在一些问题。

  1. 第一个,as Yaroslav pointed out是所有操作都添加到同一个图表中。这意味着当您调用tf.train.start_queue_runners()(或运行tf.initialize_all_variables()操作符号)时,会话将执行的工作量与您运行此代码段的次数成正比。您可以在对此代码的调用之间调用tf.reset_default_graph(),但是一种更清晰的隔离方式可能是每次都声明一个单独的图:

    with tf.Graph().as_default():  # Declares a new graph for the life of the block.
        filename_queue = tf.train.string_input_producer(filename_list)
        init = tf.initialize_all_variables()
        image = []
        label = []
        with tf.Session() as sess:
            # ...
            coord.join(threads)
    
  2. 第二个问题是对sess.run(tf_image)sess.run(tf_label)的单独调用意味着图像和标签之间的关联丢失了。当您致电sess.run(tf_image)时,您会从阅读器中使用图片标签,但会丢弃标签(反之亦然sess.run(tf_label)。正确的解决方法是将其同时提取同样的步骤:

    image_val, label_val = sess.run([tf_image, tf_label])
    image.append(image_val)
    label.append(label_val)
    
  3. 最后一个问题 - 即使您重置图表也可能导致问题 - 是代码在调用tf.train.start_queue_runners()后将节点添加到图表中。 TensorFlow图上可能存在数据争用,因为read_and_decode()会向图中添加节点,而并行队列运行程序会同时读取它,而tf.Graph不是线程安全的。

    处理此问题的最佳方法是在启动队列运行程序之前定义所有图形:

    with tf.Graph().as_default():
        filename_queue = tf.train.string_input_producer(filename_list)
        image = []
        label = []
        tf_image, tf_label = read_and_decode(filename_queue)
    
        with tf.Session() as sess:
           coord = tf.train.Coordinator()
           threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
           for i in range(len(filename_list)):
               image_val, label_val = sess.run([tf_image, tf_label])
               image.append(image_val)
               label.append(label_val)
    
           coord.request_stop()
           coord.join(threads)
    

答案 1 :(得分:1)

默认为tf。命令附加到具有新名称的默认图形。您可以在第二次运行代码段之前使用tf.reset_default_graph()来清除默认图表。