我正在学习如何从Udemy课程中在tensorflow 2.0和Keras中从头开始创建MNIST模型。
所以,我得到了mnist数据集,如下所示
mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
一切都很好,即使我测试模型的准确性达到97%,我也很高兴。
当我尝试做一些不同于课程的事情时,问题就开始了。我尝试使用matplotlib plt.imshow()
从mnist_dataset打印一些示例,但我完全失败了。然后,我开始了一些研究,找到了解决方案,我需要获取像这样的数据集:
mnist_dataset2 = tfds.load(name = 'mnist')
mnistt = mnist_dataset2['train']
其中mnistt
是我可以使用matplotlib处理和打印的数据集。
所以我的问题如下:在哪里可以获取有关可以获取的tfds.load()类型的信息,以及如何根据需要正确地操作它们? (并且可以像我一样在张量流中从初学者开始扩展)。
答案 0 :(得分:0)
tfds.load
方法的主调用包含您需要的所有内容:
mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
name="mnist"
->您正在指定要使用的构建器(错误)with_info=True
->您要让tfds.load
返回info
对象,该对象包含有关返回数据集的所有您需要了解的 as_supervised=True
->您要让tfds.load
仅获得监督学习任务所需的数据集元素(图像和标签对)。您第一次尝试使用mnist_dataset
获取数据(与matplotlib
一起使用)失败了,因为您可以从中看到
print(mnist_info) #run me!
数据集包含2个不同的分割:train
和test
。
tfds.core.DatasetInfo(
name='mnist',
version=1.0.0,
description='The MNIST database of handwritten digits.',
urls=['https://storage.googleapis.com/cvdf-datasets/mnist/'],
features=FeaturesDict({
'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
}),
total_num_examples=70000,
splits={
'test': 10000,
'train': 60000,
},
supervised_keys=('image', 'label'),
citation="""@article{lecun2010mnist,
title={MNIST handwritten digit database},
author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
volume={2},
year={2010}
}""",
redistribution_info=,
)
因此,tfds.load
返回的对象是字典:
{
"train": <train dataset>,
"test": <test dataset>
}
实际上,在示例的下一行中,您是通过以下方式提取“ train”和“ test”数据集的:
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
从mnist_info
对象中,您可以获得处理数据集所需的所有信息:分割数,数据类型(例如,“图像”是dtype为tf.uint8的28x28x1图像)等。 ..
答案 1 :(得分:0)
使用此代码加载mnist时出现错误
mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
错误 init ()缺少2个必需的位置参数:“ op”和“ message”
源Udemy课程
答案 2 :(得分:0)
尝试
x_train, y_train = Next(iter(mnist_train))
然后绘制x_train