以下代码旨在加载具有张量流的一对图像。然而,它处于僵局。特别是,在我添加了tf.train.batch部分之后。如果我在tf.train.batch之前得到了值,它就可以了。
你能指出哪一部分不正确吗?
import tensorflow as tf
batch_size = 1
alist = [['a.jpg', 'b.jpg']] * 1000
logdir = './logdir'
NUM_THREADS = 5
with tf.Graph().as_default():
init = tf.constant(0, dtype=tf.int64)
global_step = tf.get_variable(name='global_step', trainable=False, initializer=init)
input_queue = tf.FIFOQueue(50, dtypes=[tf.string, tf.string], shapes=[[], []])
input_enqueue_op = input_queue.enqueue_many([alist[:, 0], alist[:, 1]])
input_dir, target_dir = input_queue.dequeue()
input_value = tf.read_file(input_dir)
input_img = tf.image.decode_jpeg(input_value, channels=3)
target_value = tf.read_file(target_dir)
target_img = tf.image.decode_jpeg(target_value, channels=3)
input_img = tf.image.resize_images(input_img, [224, 224])
input_img.set_shape((224, 224, 3))
input_img = tf.image.per_image_standardization(input_img)
target_img = tf.image.resize_images(target_img, [224, 224])
target_img.set_shape((224, 224, 3))
target_img = tf.image.per_image_standardization(target_img)
img_batch, gt_img_batch = tf.train.batch(
[input_img, target_img],
batch_size = 1,
num_threads = 1,
# shapes= [input_img.get_shape(), target_img.get_shape()],
capacity = 30,
enqueue_many=False,
allow_smaller_final_batch=True,
name='input_batch')
qr = tf.train.QueueRunner(input_queue, [input_enqueue_op] * NUM_THREADS)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = qr.create_threads(sess, coord=coord, start=True)
for i in range(10):
a = sess.run(img_batch)
print(a.shape)
# Wait for threads to finish.
coord.request_stop()
coord.join(threads)
答案 0 :(得分:1)
tf.train.batch
创建自己的队列运行器:
此功能使用队列实现。队列的QueueRunner被添加到当前Graph的QUEUE_RUNNER集合中。
他们也需要开始。 TensoFlow有一个函数可以启动图中收集的所有队列运行器:tf.train.start_queue_runners
。
使用tf.train.add_queue_runner
将队列运行器添加到相应的集合也是有意义的。这样start_queue_runners
也将启动队列运行。