如何分批提供数据TensorFlow CNN?

时间:2017-09-12 06:28:30

标签: tensorflow neural-network

github或其他博客上的几乎所有示例都使用mnist数据集进行演示。当我尝试为我的图像数据使用相同的深度NN时,我遇到以下问题。

他们使用:

  batch_x, batch_y = mnist.train.next_batch(batch_size)
  # Run optimization op (backprop)
  sess.run(train_op, feed_dict={X: trainimg, Y: trainlabel, keep_prob: 0.8})

next_batch方法分批提供数据。

我的问题是:

我们是否有任何类似的方法批量提供数据?

1 个答案:

答案 0 :(得分:3)

你应该看看tf.contrib.data.Dataset。您可以创建输入管道:定义源,应用转换并批量处理。有关导入数据,请参阅programmer's guide

来自文档:

  

数据集API使您能够从简单,可重复使用的部分构建复杂的输入管道。例如,图像模型的管道可能会聚合分布式文件系统中的文件数据,对每个图像应用随机扰动,并将随机选择的图像合并为一批进行培训

编辑:

我猜你所拥有的是一系列图片(文件名)。以下是程序员指南中的一个示例。

根据您的输入文件,转换部分将更改。以下是使用图片文件数组的摘录。

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# labels[i] is the label for the image in filenames[i].
labels = tf.constant([0, 37, ...])

dataset = tf.contrib.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

# Now you have a dataset of (image, label). Basically kind of a list with
# all your pictures encoded along with a label. 
# Batch it.
dataset = dataset.batch(32)
# Create an iterator.
iterator = dataset.make_one_shot_iterator()
# Retrieve the next element.
image_batch, label_batch = iterator.get_next()

您也可以随机播放图像。

现在,您可以在模型定义中使用image_batchlabel_batch作为占位符。