我正在尝试使用Python中的cifar100 dataset
库加载tesorflow_dataset
。用.load()
加载数据后,我正尝试用.map()
将图像转换为设置的大小,地图内部的lambda给出了
TypeError:()缺少2个必需的位置参数: 'coarse_label'和'label'
运行我的代码时。
在将标签信息保留在数据中的同时转换这些图像的最佳方法是什么?我不确定lambda函数如何与数据集交互。
这是通过tensorflow 2.0.0b1,tensorflow数据集1.0.2和Python 3.7.3完成的
def transform_images(x_train, size):
x_train = tf.image.resize(x_train, (size, size))
x_train = x_train / 255
return x_train
train_dataset = tfds.load(name="cifar100", split=tfds.Split.TRAIN)
train_dataset = train_dataset.map(lambda image, coarse_label, label:
(dataset.transform_images(image, FLAGS.size), coarse_label, label))
答案 0 :(得分:1)
train_dataset
的每一行都是字典,而不是元组。因此,您不能像lambda
那样使用lambda image, coarse_label, label
。
import tensorflow as tf
import tensorflow_datasets as tfds
train_dataset = tfds.load(name="cifar100", split=tfds.Split.TRAIN)
print(train_dataset.output_shapes)
# {'image': TensorShape([32, 32, 3]), 'label': TensorShape([]), 'coarse_label': TensorShape([])}
您应该按如下方式使用它:
def transform_images(row, size):
x_train = tf.image.resize(row['image'], (size, size))
x_train = x_train / 255
return x_train, row['coarse_label'], row['label']
train_dataset = train_dataset.map(lambda row:transform_images(row, 16))
print(train_dataset.output_shapes)
# (TensorShape([16, 16, 3]), TensorShape([]), TensorShape([]))