如何在Tensorflow 2.x中正确操作tfds.load()数据集?

时间:2019-11-03 19:06:39

标签: python tensorflow tensorflow-datasets tensorflow2.0

我正在学习如何从Udemy课程中在tensorflow 2.0和Keras中从头开始创建MNIST模型。

所以,我得到了mnist数据集,如下所示

mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

一切都很好,即使我测试模型的准确性达到97%,我也很高兴。

当我尝试做一些不同于课程的事情时,问题就开始了。我尝试使用matplotlib plt.imshow()从mnist_dataset打印一些示例,但我完全失败了。然后,我开始了一些研究,找到了解决方案,我需要获取像这样的数据集:

mnist_dataset2 = tfds.load(name = 'mnist')
mnistt = mnist_dataset2['train']

其中mnistt是我可以使用matplotlib处理和打印的数据集。

所以我的问题如下:在哪里可以获取有关可以获取的tfds.load()类型的信息,以及如何根据需要正确地操作它们? (并且可以像我一样在张量流中从初学者开始扩展)。

3 个答案:

答案 0 :(得分:0)

tfds.load方法的主调用包含您需要的所有内容:

mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
  • name="mnist"->您正在指定要使用的构建器(错误)
  • with_info=True->您要让tfds.load返回info对象,该对象包含有关返回数据集的所有您需要了解的
  • as_supervised=True->您要让tfds.load仅获得监督学习任务所需的数据集元素(图像和标签对)。

您第一次尝试使用mnist_dataset获取数据(与matplotlib一起使用)失败了,因为您可以从中看到

print(mnist_info) #run me!

数据集包含2个不同的分割:traintest

tfds.core.DatasetInfo(
    name='mnist',
    version=1.0.0,
    description='The MNIST database of handwritten digits.',
    urls=['https://storage.googleapis.com/cvdf-datasets/mnist/'],
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)

因此,tfds.load返回的对象是字典

{
   "train": <train dataset>,
   "test": <test dataset>
}

实际上,在示例的下一行中,您是通过以下方式提取“ train”和“ test”数据集的:

mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

mnist_info对象中,您可以获得处理数据集所需的所有信息:分割数,数据类型(例如,“图像”是dtype为tf.uint8的28x28x1图像)等。 ..

答案 1 :(得分:0)

使用此代码加载mnist时出现错误

mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)

错误 init ()缺少2个必需的位置参数:“ op”和“ message”

源Udemy课程

答案 2 :(得分:0)

尝试

x_train, y_train = Next(iter(mnist_train))

然后绘制x_train