访问从tensorflow_datasets加载的CIFAR-100中的'coarse_label'

时间:2019-10-23 12:58:29

标签: python tensorflow keras tensorflow-datasets

我正在使用tensorflow_datasets (tfds doc)

加载CIFAR-100
train, test = tfds.load(name="cifar100:3.*.*", split=["train", "test"], as_supervised=True)

CIFAR-100既有标签(100个类),也有arrough_label(20个类),如上面链接的文档所示。可以轻松访问标签,例如:

for image, label in train:
     # ... the label here is the actual label, not the coarse_label

但是,我正计划基于arough_label进行操作,例如,对其进行过滤或将其用作Keras分类器中的标签。

我该如何访问粗糙标签?

1 个答案:

答案 0 :(得分:0)

我找到了解决方案。如果我没有按照监督的方式加载,即删除了as_supervised=True

train, test = tfds.load(name="cifar100:3.*.*", split=["train", "test"])

,例如,我可以从字典中获得粗糙标签。

for item in train:
   print(item['coarse_label'])

像这样,我将能够重塑数据集。 random_labels可用于分类。但是,即使我对标签感兴趣,也必须加载as_supervised=False对我来说仍然很不自然。如果有人有更好的解决方案,我很乐意接受该答案。