我将以下功能写入培训TFRecord:
feature = {'label': _int64_feature(gt),
'image': _bytes_feature(tf.compat.as_bytes(im.tostring())),
'height': _int64_feature(h),
'width': _int64_feature(w)}
我正在阅读,就像:
train_dataset = tf.data.TFRecordDataset(train_file)
train_dataset = train_dataset.map(parse_func)
train_dataset = train_dataset.shuffle(buffer_size=1)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(batch_size)
而我的parse_func看起来像这样:
def parse_func(ex):
feature = {'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64)}
features = tf.parse_single_example(ex, features=feature)
image = tf.decode_raw(features['image'], tf.uint8)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
im_shape = tf.stack([width, height])
image = tf.reshape(image, im_shape)
label = tf.cast(features['label'], tf.int32)
return image, label
现在,我要获得图像和标签的形状,例如:
image.get_shape().as_list()
打印
[无,无,无]
而不是
[无,224,224](图像的大小(批量,宽度,高度))
有什么函数可以给我这些张量的大小吗?
答案 0 :(得分:0)
由于您的地图函数“ parse_func”作为操作的一部分在图形中,并且它不知道输入的固定大小并标记为先验,因此使用get_shape()不会返回预期的固定形状。
如果您的图片,标签形状是固定的,例如,您可以尝试重塑图片,具有已知大小的标签(这实际上将不会做任何事情,但是会显式设置标签的大小输出张量)。
例如 图片= tf.reshape(图片,[224,224])
有了它,您应该能够按预期获得get_shape()结果。
答案 1 :(得分:0)
另一种解决方案是存储编码的图像,而不是解码的原始字节,这样,您只需在读取tfrecords时使用tensorflow将图像解码回去,这也将帮助您节省存储空间,这样您就可以获取从张量中得出图像的形状。
# Load your image as you would normally do then do:
# Convert the image to raw bytes.
img_bytes = tf.io.encode_jpeg(img).numpy()
# Create a dict with the data we want to save in the
# TFRecords file. You can add more relevant data here.
data = \
{'image': wrap_bytes(img_bytes),
'label': wrap_int64(label)}
# Wrap the data as TensorFlow Features.
feature = tf.train.Features(feature=data)
# Wrap again as a TensorFlow Example.
example = tf.train.Example(features=feature)
# Serialize the data.
serialized = example.SerializeToString()
# Write the serialized data to the TFRecords file.
writer.write(serialized)
features = \
{
'image': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64)
}
# Parse the serialized data so we get a dict with our data.
parsed_example = tf.io.parse_single_example(serialized=serialized,
features=features)
# Get the image as raw bytes.
image_raw = parsed_example['image']
# Decode the raw bytes so it becomes a tensor with type.
image = tf.io.decode_jpeg(image_raw)
image = tf.cast(image, tf.float32) # optional
# Get the label associated with the image.
label = parsed_example['label']
# The image and label are now correct TensorFlow types.
return image, label