我正在尝试从tfrecords建立数据管道,这是我的代码
def _parse_image_function(example_proto):
keys_to_features = {
'image/encoded': tf.io.FixedLenFeature((), tf.string),
'image/format': tf.io.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/height': tf.io.FixedLenFeature([1], tf.int64),
'image/width': tf.io.FixedLenFeature([1], tf.int64),
'image/channels': tf.io.FixedLenFeature([1], tf.int64),
'image/shape': tf.io.FixedLenFeature([3], tf.int64),
'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/label': tf.io.VarLenFeature(dtype=tf.int64),
'image/object/bbox/difficult': tf.io.VarLenFeature(dtype=tf.int64),
'image/object/bbox/truncated': tf.io.VarLenFeature(dtype=tf.int64),
}
example = tf.io.parse_single_example(example_proto, keys_to_features)
image = tf.io.decode_raw(example['image/encoded'], tf.int32)
return image
然后,我在解码后得到图像
for img in train_ds:
print(img.numpy())
但是我遇到了错误
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __inference_Dataset_map__parse_image_function_78}} Input to DecodeRaw has length 286478 that is not a multiple of 4, the size of int32
[[{{node DecodeRaw}}]] [Op:IteratorGetNextSync]
我该如何解决?
答案 0 :(得分:1)
如评论中所述,错误消息是关于解码错误,而不是迭代问题。您正在创建数据集对象并对其进行正确的迭代。
image = tf.io.decode_raw(example['image/encoded'], tf.int32)
告诉TensorFlow解码存储在该密钥中的数据,作为int32
s的张量。也就是32位整数的原始二进制值,例如,带有.data
的NumPy数组中dtype=np.int32
的内容。
由于您已经读入.jpg
文件的二进制内容,因此我假设您在该键值下具有JPEG图像的二进制文件。对于这种情况,您应该改用decode_jpeg
方法。您应该使用:
image = tf.io.decode_jpeg(example['image/encoded'])
。
decode_jpeg
还为您提供了一些有关如何解码JPEG数据的选项(例如,仅灰度级,色度升采样方法)。 here
decode_jpeg
的完整文档。
此外,TensorFlow还提供了decode_image
,它根据二进制数据自动调用正确的受支持图像格式解码器。参见文档here。