如何迭代tensorflow数据集?

时间:2020-02-08 04:27:30

标签: python tensorflow tensorflow-datasets

我正在尝试从tfrecords建立数据管道,这是我的代码

def _parse_image_function(example_proto):
    keys_to_features = {
        'image/encoded': tf.io.FixedLenFeature((), tf.string),
        'image/format': tf.io.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height': tf.io.FixedLenFeature([1], tf.int64),
        'image/width': tf.io.FixedLenFeature([1], tf.int64),
        'image/channels': tf.io.FixedLenFeature([1], tf.int64),
        'image/shape': tf.io.FixedLenFeature([3], tf.int64),
        'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/label': tf.io.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/difficult': tf.io.VarLenFeature(dtype=tf.int64),
        'image/object/bbox/truncated': tf.io.VarLenFeature(dtype=tf.int64),
    }

    example = tf.io.parse_single_example(example_proto, keys_to_features)
    image = tf.io.decode_raw(example['image/encoded'], tf.int32)
    return image

然后,我在解码后得到图像

    for img in train_ds:
        print(img.numpy())

但是我遇到了错误

tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __inference_Dataset_map__parse_image_function_78}} Input to DecodeRaw has length 286478 that is not a multiple of 4, the size of int32
     [[{{node DecodeRaw}}]] [Op:IteratorGetNextSync]

我该如何解决?

1 个答案:

答案 0 :(得分:1)

如评论中所述,错误消息是关于解码错误,而不是迭代问题。您正在创建数据集对象并对其进行正确的迭代。

image = tf.io.decode_raw(example['image/encoded'], tf.int32)告诉TensorFlow解码存储在该密钥中的数据,作为int32 s的张量。也就是32位整数的原始二进制值,例如,带有.data的NumPy数组中dtype=np.int32的内容。

由于您已经读入.jpg文件的二进制内容,因此我假设您在该键值下具有JPEG图像的二进制文件。对于这种情况,您应该改用decode_jpeg方法。您应该使用:

image = tf.io.decode_jpeg(example['image/encoded'])

decode_jpeg还为您提供了一些有关如何解码JPEG数据的选项(例如,仅灰度级,色度升采样方法)。 here

提供了decode_jpeg的完整文档。

此外,TensorFlow还提供了decode_image,它根据二进制数据自动调用正确的受支持图像格式解码器。参见文档here