tensorflow_dataset使用dataset.map进行图像变换

时间:2019-07-03 17:07:59

标签: tensorflow tensorflow-datasets

我正在尝试使用Python中的cifar100 dataset库加载tesorflow_dataset。用.load()加载数据后,我正尝试用.map()将图像转换为设置的大小,地图内部的lambda给出了

  

TypeError:()缺少2个必需的位置参数:   'coarse_label'和'label'

运行我的代码时。

在将标签信息保留在数据中的同时转换这些图像的最佳方法是什么?我不确定lambda函数如何与数据集交互。

这是通过tensorflow 2.0.0b1,tensorflow数据集1.0.2和Python 3.7.3完成的

def transform_images(x_train, size):
    x_train = tf.image.resize(x_train, (size, size))
    x_train = x_train / 255
    return x_train

train_dataset = tfds.load(name="cifar100", split=tfds.Split.TRAIN)
train_dataset = train_dataset.map(lambda image, coarse_label, label: 
        (dataset.transform_images(image, FLAGS.size), coarse_label, label))

1 个答案:

答案 0 :(得分:1)

train_dataset的每一行都是字典,而不是元组。因此,您不能像lambda那样使用lambda image, coarse_label, label

import tensorflow as tf
import tensorflow_datasets as tfds

train_dataset = tfds.load(name="cifar100", split=tfds.Split.TRAIN)
print(train_dataset.output_shapes)

# {'image': TensorShape([32, 32, 3]), 'label': TensorShape([]), 'coarse_label': TensorShape([])}

您应该按如下方式使用它:

def transform_images(row, size):
    x_train = tf.image.resize(row['image'], (size, size))
    x_train = x_train  / 255
    return x_train, row['coarse_label'], row['label']

train_dataset = train_dataset.map(lambda row:transform_images(row, 16))
print(train_dataset.output_shapes)

# (TensorShape([16, 16, 3]), TensorShape([]), TensorShape([]))