如何从tf.data.Dataset.zip((images,labels))获取两个tf.dataset

时间:2018-10-21 17:38:22

标签: python tensorflow mnist

我正在研究Python / tensorflow / mnist教程。

自从使用tensorflow网站的原始代码几个星期以来,我得到警告,图像数据集将很快被弃用,我应该使用以下代码: https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py

我使用以下代码将其加载:

from tensorflow.models.official.mnist import dataset
trainfile = dataset.train(data_dir)

哪个返回:

tf.data.Dataset.zip((images, labels))

问题是我找不到以以下方式将它们分开的方法:

  trainfile = dataset.train(data_dir)
  train_data= trainfile.images
  train_label= trainfile.label

但这显然不起作用,因为属性图像和标签不存在。 trainfile是一个tf.dataset。

知道tf.dataset是由int32和float32组成的,我尝试过:

  train_data = trainfile.map(lambda x,y : x.dtype == tf.float32)

但是它返回并清空数据集。

我坚持这样做(但会被保留),因为这是本教程的工作方式:

https://www.tensorflow.org/tutorials/estimators/cnn

我看到了很多从数据集中获取元素的解决方案,但是从以下代码中完成的zip操作中没有任何回头路可走

tf.data.Dataset.zip((images, labels))

在此先感谢您的帮助。

3 个答案:

答案 0 :(得分:1)

我希望这会有所帮助:

inputs = tf.placeholder(tf.float32, shape=(None, 784), name='inputs')
outputs = tf.placeholder(tf.float32, shape=(None,), name='outputs')

#Prepare a tensorflow dataset
ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))

ds = ds.shuffle(buffer_size=10, reshuffle_each_iteration=True).batch(batch_size=batch_size, drop_remainder=True).repeat()
iter = ds.make_one_shot_iterator()
next = iter.get_next()

inputs = next[0]
outputs = next[1]

答案 1 :(得分:0)

最好将单个迭代器返回图像和标签,而不是将它们分为两个数据集,一个用于图像,另一个用于标签。

之所以选择此示例,是因为即使经过一系列复杂的混洗,重新排序,过滤等操作(就像在非平凡的输入管道中一样),也可以更加轻松地确保将每个示例与其标签匹配。 / p>

答案 2 :(得分:0)

您可以可视化图像并找到其相关标签

ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))

ds = ds.shuffle(buffer_size=10).batch(batch_size=batch_size)
iter = ds.make_one_shot_iterator()
next = iter.get_next()

def display(image, label):
# display image
   ...
   plt.imshow(image)
   ...

with tf.Session() as sess:
    try:
        while True:
             image, label = sess.run(next) 
             # image = numpy array (batch, image_size)
             # label = numpy array (batch, label)
        display(image[0], label[0]) #display first image in batch
    except:
        pass