我正在尝试准备数据集以进行图像分割。我使用以下代码将所有图像及其关联的注释转换为.tfrecord
文件:
writer = tf.python_io.TFRecordWriter(tfrecords_filename)
for img_path, annotation_path in filename_pairs:
img = np.array(Image.open(img_path))
annotation = np.array(Image.open(annotation_path))
height = img.shape[0]
width = img.shape[1]
img_raw = img.tostring()
annotation_raw = annotation.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(height),
'width': _int64_feature(width),
'image_raw': _bytes_feature(img_raw),
'mask_raw': _bytes_feature(annotation_raw)}))
writer.write(example.SerializeToString())
现在,我正在尝试将这些记录加载到TF数据集中:
dataset = tf.data.TFRecordDataset(training_filenames).shuffle(1000).repeat(4).batch(32)
现在,如果我尝试在此dataset
中显示第一个图像/注释对,则可以按预期工作:
batch = next(iter(dataset))
tensor = batch[0]
image, annotation = _parse_function(tensor)
annotation = np.squeeze(annotation.numpy()[:, :], axis=2)
plt.figure()
plt.imshow(image.numpy())
plt.imshow(annotation, alpha=0.5)
plt.show()
我在其中使用_parse_function
预处理记录以提取特征(我有意在急切的执行模式下使用TensorFlow ):
def _parse_function(example_proto):
features = {'height': tf.FixedLenFeature(1, tf.int64),
'width': tf.FixedLenFeature(1, tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),
'mask_raw': tf.FixedLenFeature([], tf.string)}
parsed_features = tf.parse_single_example(example_proto, features)
annotation = tf.decode_raw(parsed_features['mask_raw'], tf.uint8)
height = tf.cast(parsed_features['height'], tf.int32)
width = tf.cast(parsed_features['width'], tf.int32)
height = height.numpy()[0]
width = width.numpy()[0]
image = tf.decode_raw(parsed_features['image_raw'], tf.uint8)
image = tf.reshape(image, tf.stack([height, width, 3]))
annotation = tf.reshape(annotation, tf.stack([height, width, 1]))
return image, annotation
当然,我宁愿将整个dataset
变成可以直接用于训练细分模型的东西。
但是,如果我尝试使用dataset
对整个dataset.map(_parse_function)
进行预处理以将其转换为一组功能,则似乎正在馈入example_proto
的{{1}}与做_parse_function
时得到的不同。更准确地说,它是等级0的张量(因此只是一个量级),因此无法正确提取特征。
我对TF还是比较陌生,不十分了解为什么是这种情况,也不知道这个张量代表什么。
next(iter(dataset))[0]
是否批量调用回调函数而不是基础示例?我曾尝试删除map
,但是文档说默认行为是生成大小为1的批次,但这不一定能解决问题。
任何帮助将不胜感激!