如何在TensorFlow中访问数据集的功能词典

时间:2019-06-25 17:28:12

标签: python-3.x tensorflow

我使用tensorflow数据集将MNIST数据集集成到Tensorflow中,现在想用Matplotlib可视化单个图像。我是按照以下指南操作的:https://www.tensorflow.org/datasets/overview

不幸的是,我在执行过程中收到一条错误消息。但它在《指南》中效果很好。

根据指南,您必须使用take()函数创建仅包含一张图像的新数据集。然后,可以在指南中访问功能。在尝试过程中,我总是收到错误消息。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v1 as tf

import tensorflow_datasets as tfds



mnist_train, info = tfds.load(name="mnist", split=tfds.Split.TRAIN, with_info=True)
assert isinstance(mnist_train, tf.data.Dataset)

mnist_example = mnist_train.take(50)

#The error is raised in the next line. 
image = mnist_example["image"]
label = mnist_example["label"]

plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
print("Label: %d" % label.numpy())

这是错误消息:

Traceback (most recent call last):
  File "D:/mnist/model.py", line 24, in <module>
    image = mnist_example["image"]
TypeError: 'DatasetV1Adapter' object is not subscriptable

有人知道我该如何解决吗?经过大量研究,我仍然没有找到解决方案。

1 个答案:

答案 0 :(得分:0)

渴望执行

首先编写代码 tf.enable_eager_execution()

为什么?

因为如果您不这样做,则需要创建图形并执行session.run()以获取一些样本

急切的执行定义(reference):

  

TensorFlow急切的执行是一个命令式编程环境,该环境可以立即评估>操作,而无需构建图:操作返回具体值>而不是构造要稍后运行的计算图

然后

如何访问数据集对象中的样本

您需要的是遍历DatasetV1Adapter对象

通过转换为numpy来访问某些样本的几种方法:

1。

mnist_example = mnist_train.take(50)
for sample in mnist_example:
    image, label = sample["image"].numpy(), sample["label"].numpy()
    plt.imshow(image[:, :, 0].astype(np.uint8), cmap=plt.get_cmap("gray"))
    plt.show()
    print("Label: %d" % label)

2。

mnist_example = tfds.as_numpy(mnist_example, graph=None)
for sample in mnist_example:
    image, label = sample["image"], sample["label"]
    plt.imshow(image[:, :, 0].astype(np.uint8), cmap=plt.get_cmap("gray"))
    plt.show()
    print("Label: %d" % label)

注意1::如果希望将所有50个样本都放在numpy数组中,则可以创建一个空数组,例如np.zeros((28, 28, 50), dtype=np.uint8)数组,并将这些图像分配给它的元素。

>

注2::出于即时展示的目的,请勿转换为np.float32,它无用,图像采用uint8格式/范围(默认情况下未标准化)