将tfrecords中的原始字节解码为tf.feature_column.numeric_column feature

时间:2017-10-18 09:38:47

标签: python tensorflow tfrecord

我有一个tfrecords文件存储图像为字节串。我想将此功能列定义为tf.feature_column.numeric_column("image", shape=[64, 64], dtype=tf.float32),但由于它未在tfrecords文件中存储为float_list,因此无效。

然后我尝试使用我定义为。

的numeric_column的normalizer_fn参数
def decode(image_bytestring):
    img = tf.reshape(tf.decode_raw(image_bytestring, tf.uint8), [28, 28])
    img = tf.cast(img, tf.float32)
    return img

...

examples = tf.parse_example(
            serialized_batch,
            tf.feature_column.make_parse_example_spec(feature_columns))

然而,第一个问题是这个feature_column生成的解析规范FixedLenFeature(shape=(28, 28), dtype=tf.float32, default_value=None)表示当它实际存储为导致错误的字符串时解析float32。因此不使用解码功能。

当使用tf.feature_column而不是将图像存储为tfrecord中的float_list时,有没有办法解决这个问题?

似乎有一个静态类型系统可以很好地保证映射函数中正确的特征类型。

1 个答案:

答案 0 :(得分:1)

也许您可以将图像存储为字符串字节,并按照常用方式读取图像?

feature_map = { 'image': tf.FixedLenFeature([], dtype=tf.string,default_value='') }
features = tf.parse_single_example(example_serialized, feature_map)
image_buffer = features['image']
image = tf.image.decode_image(image_buffer, ...)