如何从TFRecordDataset获取张量的形状

时间:2018-08-24 10:10:09

标签: tensorflow tensor tensorflow-datasets tfrecord

我将以下功能写入培训TFRecord:

feature = {'label': _int64_feature(gt),
           'image': _bytes_feature(tf.compat.as_bytes(im.tostring())),
           'height': _int64_feature(h),
           'width': _int64_feature(w)}

我正在阅读,就像:

train_dataset = tf.data.TFRecordDataset(train_file)
train_dataset = train_dataset.map(parse_func)
train_dataset = train_dataset.shuffle(buffer_size=1)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(batch_size)

而我的parse_func看起来像这样:

def parse_func(ex):
    feature = {'image': tf.FixedLenFeature([], tf.string),
               'label': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
               'height': tf.FixedLenFeature([], tf.int64),
               'width': tf.FixedLenFeature([], tf.int64)}
    features = tf.parse_single_example(ex, features=feature)
    image = tf.decode_raw(features['image'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    im_shape = tf.stack([width, height])
    image = tf.reshape(image, im_shape)
    label = tf.cast(features['label'], tf.int32)
    return image, label

现在,我要获得图像标签的形状,例如:

image.get_shape().as_list()

打印
[无,无,无]
而不是
[无,224,224](图像的大小(批量,宽度,高度))

有什么函数可以给我这些张量的大小吗?

2 个答案:

答案 0 :(得分:0)

由于您的地图函数“ parse_func”作为操作的一部分在图形中,并且它不知道输入的固定大小并标记为先验,因此使用get_shape()不会返回预期的固定形状。

如果您的图片,标签形状是固定的,例如,您可以尝试重塑图片,具有已知大小的标签(这实际上将不会做任何事情,但是会显式设置标签的大小输出张量)。

例如 图片= tf.reshape(图片,[224,224])

有了它,您应该能够按预期获得get_shape()结果。

答案 1 :(得分:0)

另一种解决方案是存储编码的图像,而不是解码的原始字节,这样,您只需在读取tfrecords时使用tensorflow将图像解码回去,这也将帮助您节省存储空间,这样您就可以获取从张量中得出图像的形状。

    # Load your image as you would normally do then do:

    # Convert the image to raw bytes.
    img_bytes = tf.io.encode_jpeg(img).numpy()

    # Create a dict with the data we want to save in the
    # TFRecords file. You can add more relevant data here.
    data = \
    {'image': wrap_bytes(img_bytes),
     'label': wrap_int64(label)}

    # Wrap the data as TensorFlow Features.
    feature = tf.train.Features(feature=data)

    # Wrap again as a TensorFlow Example.
    example = tf.train.Example(features=feature)

    # Serialize the data.
    serialized = example.SerializeToString()
            
    # Write the serialized data to the TFRecords file.
    writer.write(serialized) 

然后阅读,您可以使用:

    features = \
        {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64)            
        }

    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.io.parse_single_example(serialized=serialized,
                                             features=features)

    # Get the image as raw bytes.
    image_raw = parsed_example['image']

    # Decode the raw bytes so it becomes a tensor with type.
    image = tf.io.decode_jpeg(image_raw)
    
    image = tf.cast(image, tf.float32) # optional
    
    # Get the label associated with the image.
    label = parsed_example['label']
    
    # The image and label are now correct TensorFlow types.
    return image, label