TensorFlow读取和解码图像的BATCH

时间:2018-05-28 08:38:56

标签: python tensorflow

使用tf.train.string_input_producertf.image.decode_jpeg我设法从磁盘读取并解码单个图像。

这是代码:

# -------- Graph
filename_queue = tf.train.string_input_producer(
    [img_path, img_path])

image_reader = tf.WholeFileReader()

key, image_file = image_reader.read(filename_queue)

image = tf.image.decode_jpeg(image_file, channels=3)

# Run my network
logits = network.get_logits(image)

# -------- Session
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

logits_output = sess.run(logits)

问题是,当我查看logit_outputs的形状时,即使队列长度为2张,我也只得到1个值。

如何读取和解码整个队列?

1 个答案:

答案 0 :(得分:2)

tf.WholeFileReader()tf.train.string_input_producer()沿着N作为迭代器,因此没有一种简单的方法来评估它正在处理的完整数据集的大小。

要从中获取批量的def _parse_function(filename): image_string = tf.read_file(filename) image_decoded = tf.image.decode_image(image_string) return image_decoded # A vector of filenames. filenames = tf.constant([img_path, img_path]) dataset = tf.data.Dataset.from_tensor_slices((filenames)) dataset = dataset.map(_parse_function).batch(N) iterator = dataset.make_one_shot_iterator() next_image_batch = iterator.get_next() logits = network.get_logits(next_image_batch) # ... 个样本,您可以使用image_reader.read_up_to(filename_queue, N)

注意:您可以使用较新的tf.data管道实现相同的目标:

matrix := mat.NewDense(2, 2, []float64{0, 0, 0, 3})