创建TF数据集时无法解析TFRecords

时间:2019-03-04 00:37:01

标签: python tensorflow tensorflow-datasets

我正在尝试编写代码以解析TFRecords并创建TF数据集。我从图像列表创建TFRecords文件,并且能够读回该文件并成功解码我的图像。我的代码基于此blog中的示例。但是,当我尝试读取TFRecords文件并创建TF数据集时,它将失败,并显示以下错误:

ValueError: Argument must be a dense tensor: FixedLenFeature(shape=[], dtype=tf.int64, default_value=None) - got shape [3], but wanted [3, 0]

以下是尝试创建数据集的代码摘要:

 dataset = tf.data.TFRecordDataset(fnames)
 dataset = dataset.map(parse_tfrec)

其中parse_tfrec是解析单个原始记录的函数:

 def parse_tfrec(example_proto):
    features={
    'height': tf.FixedLenFeature([], tf.int64, default_value=IMG_SHAPE[0]),
    'width': tf.FixedLenFeature([], tf.int64, default_value=IMG_SHAPE[1]),
    'depth': tf.FixedLenFeature([], tf.int64, default_value=IMG_SHAPE[2]),
    'label': tf.FixedLenFeature([], tf.int64, default_value=0),
    'image': tf.FixedLenFeature([], tf.string, default_value=''),
    }
    parsed_features = tf.parse_single_example(example_proto, features)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    label = tf.cast(features['label'], tf.int32)
    image = tf.decode_raw(features['image'], tf.uint8)
    image_shape = tf.pack([height, width, depth])
    image = tf.reshape(image, image_shape)
    return image, label

当代码尝试从TFRecords(或任何其他存储的整数)解析height时,代码将失败。而且,我不确定我是否了解有关形状的失败消息。

有什么建议吗?

1 个答案:

答案 0 :(得分:1)

您能详细说明错误发生在哪一行吗?它是否出现在“ parse_single_example”行上?还是在下一行?

我注意到的一件事是,在您的强制转换语句中,您使用的是features字典,而不是parsed_features

将您的代码更改为这样可以解决您的问题:

height = tf.cast(parsed_features['height'], tf.int32)

让我知道问题是否仍然存在。最近,我自己花了很长时间调试tfrecords :)最初可能很难理解它们,但是最终,我能够在批处理生成时间中获得巨大的性能提升。