使用tf.train.batch时出现死锁

时间:2018-03-27 06:14:37

标签: tensorflow

以下代码旨在加载具有张量流的一对图像。然而,它处于僵局。特别是,在我添加了tf.train.batch部分之后。如果我在tf.train.batch之前得到了值,它就可以了。

你能指出哪一部分不正确吗?

import tensorflow as tf

batch_size = 1
alist = [['a.jpg', 'b.jpg']] * 1000

logdir = './logdir'
NUM_THREADS = 5

with tf.Graph().as_default():

    init = tf.constant(0, dtype=tf.int64)
    global_step = tf.get_variable(name='global_step', trainable=False, initializer=init)

    input_queue = tf.FIFOQueue(50, dtypes=[tf.string, tf.string], shapes=[[], []])
    input_enqueue_op = input_queue.enqueue_many([alist[:, 0], alist[:, 1]])
    input_dir, target_dir = input_queue.dequeue()

    input_value = tf.read_file(input_dir)
    input_img = tf.image.decode_jpeg(input_value,  channels=3)
    target_value = tf.read_file(target_dir) 
    target_img = tf.image.decode_jpeg(target_value,  channels=3)

    input_img = tf.image.resize_images(input_img, [224, 224])
    input_img.set_shape((224, 224, 3))
    input_img = tf.image.per_image_standardization(input_img)

    target_img = tf.image.resize_images(target_img, [224, 224])
    target_img.set_shape((224, 224, 3))
    target_img = tf.image.per_image_standardization(target_img)

    img_batch, gt_img_batch = tf.train.batch(
        [input_img, target_img],
        batch_size = 1,
        num_threads = 1, 
        # shapes= [input_img.get_shape(), target_img.get_shape()],
        capacity = 30,
        enqueue_many=False,
        allow_smaller_final_batch=True,
        name='input_batch')


    qr = tf.train.QueueRunner(input_queue, [input_enqueue_op] * NUM_THREADS)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = qr.create_threads(sess, coord=coord, start=True)

        for i in range(10):
            a = sess.run(img_batch)
            print(a.shape)

        # Wait for threads to finish.
        coord.request_stop()
        coord.join(threads)

1 个答案:

答案 0 :(得分:1)

tf.train.batch创建自己的队列运行器:

  

此功能使用队列实现。队列的QueueRunner被添加到当前Graph的QUEUE_RUNNER集合中。

他们也需要开始。 TensoFlow有一个函数可以启动图中收集的所有队列运行器:tf.train.start_queue_runners

使用tf.train.add_queue_runner将队列运行器添加到相应的集合也是有意义的。这样start_queue_runners也将启动队列运行。