如何在Tensorflow输入管道中堆叠通道?

时间:2017-08-14 15:48:57

标签: python file-io tensorflow queue

几周前我开始使用tf,现在正在努力处理输入队列。 我想要做的是以下内容:我有一个包含477个时间灰度图像的文件夹。现在我想要,例如获取前3个图像并将它们堆叠在一起(=> 600,600,3),这样我得到一个包含3个通道的示例。接下来我想拍摄第四张图像并将其用作标签(仅1通道=> 600,600,1)。然后我想将它们传递给tf.train.batch并创建批次。

我想我找到了一个解决方案,请参阅下面的代码。但我想知道是否有更时尚的解决方案。

我的实际问题是:队列结束时会发生什么。因为我总是从队列中挑选4个图像(3个用于输入,1个用于标签)并且我的队列中有477个图像,所以事情无法解决。然后再次填充我的队列并继续(所以如果队列中剩下1个图像,它会拍摄此图像,再次填满队列并再拍摄2张图像以获得所需的3张图像?)。或者如果我想要一个合适的解决方案,我是否需要在我的文件夹中将4个可被4整除的图像?

def read_image(filename_queue):
  reader = tf.WholeFileReader()
  _, value = reader.read(filename_queue)
  image = tf.image.decode_png(value, dtype=tf.uint8)
  image = tf.cast(image, tf.float32)
  image = tf.image.resize_images(image, [600, 600])
  return image

def input_pipeline(file_names, batch_size, num_epochs=None):

  filename_queue = tf.train.string_input_producer(file_names, num_epochs=num_epochs, shuffle=False)
  image1 = read_image(filename_queue)
  image2 = read_image(filename_queue)
  image3 = read_image(filename_queue)
  image = tf.concat([image1, image2, image3,], axis=2)
  label = read.image(filename_queue)

  # Reshape is necessary, otherwise I get an error..
  image = tf.reshape(image, [600, 600, 3])
  label = tf.reshape(label, [600, 600, 1])

  min_after_dequeue = 200
  capacity = min_after_dequeue + 12 * batch_size
  image_batch, label_batch = tf.train.batch([image, label],
                             batch_size=batch_size,
                             num_threads=12,
                             capacity=capacity)
  return image_batch, label_batch

感谢您的帮助!

1 个答案:

答案 0 :(得分:0)

  

但我想知道是否有更时尚的解决方案

是的!有一个更好,更快的解决方案。首先,您需要重新设计数据库,因为您希望将3个灰色图像组合成1个rgb图像进行训练;从灰色图像准备RGB图像的数据集(它将在训练期间节省大量时间)。

  

重新设计检索数据的方式

  # retrieve image and corresponding label at the same time 
  # here if you set the num_epochs=None, the queue will run continuously; and it will take-care of the data need for training till end
  filename_queue = tf.train.string_input_producer([file_names_images_list, corresponding_file_names_label_list], num_epochs=None, shuffle=False)

  image = read_image(filename_queue[0])
  label = read_image(filename_queue[1])