使用输入队列的Tensorflow训练卡住了

时间:2016-02-08 16:25:15

标签: python multithreading queue tensorflow

我正在尝试构建类似于this教程中的NN培训。

我的代码如下:

def train():
    init_op = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    step = 0

    try:
        while not coord.should_stop():
            step += 1
            print 'Training step %i' % step
            training = train_op()
            sess.run(training)

    except tf.errors.OutOfRangeError:
        print 'Done training - epoch limit reached.'
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()

MIN_NUM_EXAMPLES_IN_QUEUE = 10
NUM_PRODUCING_THREADS = 1
NUM_CONSUMING_THREADS = 1

def train_op():
    images, true_labels = inputs()
    predictions = NET(images)
    true_labels = tf.cast(true_labels, tf.float32)
    loss = tf.nn.softmax_cross_entropy_with_logits(predictions, true_labels)
    return OPTIMIZER.minimize(loss)


def inputs():
    filenames = [os.path.join(FLAGS.train_dir, filename) 
        for filename in os.listdir(FLAGS.train_dir) 
        if os.path.isfile(os.path.join(FLAGS.train_dir, filename))]
    filename_queue = tf.train.string_input_producer(filenames,
        num_epochs=FLAGS.training_epochs, shuffle=True)

    example_list = [_read_and_preprocess_image(filename_queue) 
        for _ in xrange(NUM_CONSUMING_THREADS)]

    image_batch, label_batch = tf.train.shuffle_batch_join(
        example_list,
        batch_size=FLAGS.batch_size,
        capacity=MIN_NUM_EXAMPLES_IN_QUEUE + (NUM_CONSUMING_THREADS + 2) * FLAGS.batch_size,
        min_after_dequeue=MIN_NUM_EXAMPLES_IN_QUEUE)

    return image_batch, label_batch

教程说

  

这些要求您在运行任何培训或推理步骤之前致电tf.train.start_queue_runners,否则它将永远挂起。

。我正在呼叫tf.train.start_queue_runners,但train()的执行仍然会在第一次出现sess.run(training)时停滞不前。

有人知道我做错了吗?

1 个答案:

答案 0 :(得分:4)

每次尝试运行训练循环时,您都在重新定义网络。

请记住,TensorFlow定义了一个执行图,然后执行它。您想在运行循环之外调用train_op(),并且需要在调用initialize_all_variablestf.train.start_queue_runners之前定义该图表