如何将函数正确映射到数据集的每个记录

时间:2018-09-25 10:32:11

标签: python tensorflow

目标

我正在尝试准备数据集以进行图像分割。我使用以下代码将所有图像及其关联的注释转换为.tfrecord文件:

writer = tf.python_io.TFRecordWriter(tfrecords_filename)

for img_path, annotation_path in filename_pairs:
    img = np.array(Image.open(img_path))
    annotation = np.array(Image.open(annotation_path))
    height = img.shape[0]
    width = img.shape[1]

    img_raw = img.tostring()
    annotation_raw = annotation.tostring()

    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'image_raw': _bytes_feature(img_raw),
        'mask_raw': _bytes_feature(annotation_raw)}))

    writer.write(example.SerializeToString())

现在,我正在尝试将这些记录加载到TF数据集中:

dataset = tf.data.TFRecordDataset(training_filenames).shuffle(1000).repeat(4).batch(32)

实验

现在,如果我尝试在此dataset中显示第一个图像/注释对,则可以按预期工作:

batch = next(iter(dataset))
tensor = batch[0]

image, annotation = _parse_function(tensor)
annotation = np.squeeze(annotation.numpy()[:, :], axis=2)
plt.figure()
plt.imshow(image.numpy())
plt.imshow(annotation, alpha=0.5)
plt.show()

我在其中使用_parse_function预处理记录以提取特征(我有意在急切的执行模式下使用TensorFlow ):

def _parse_function(example_proto):
    features = {'height': tf.FixedLenFeature(1, tf.int64),
                'width': tf.FixedLenFeature(1, tf.int64),
                'image_raw': tf.FixedLenFeature([], tf.string),
                'mask_raw': tf.FixedLenFeature([], tf.string)}
    parsed_features = tf.parse_single_example(example_proto, features)

    annotation = tf.decode_raw(parsed_features['mask_raw'], tf.uint8)

    height = tf.cast(parsed_features['height'], tf.int32)
    width = tf.cast(parsed_features['width'], tf.int32)
    height = height.numpy()[0]
    width = width.numpy()[0]

    image = tf.decode_raw(parsed_features['image_raw'], tf.uint8)

    image = tf.reshape(image, tf.stack([height, width, 3]))
    annotation = tf.reshape(annotation, tf.stack([height, width, 1]))

    return image, annotation

实际问题

当然,我宁愿将整个dataset变成可以直接用于训练细分模型的东西。

但是,如果我尝试使用dataset对整个dataset.map(_parse_function)进行预处理以将其转换为一组功能,则似乎正在馈入example_proto的{​​{1}}与做_parse_function时得到的不同。更准确地说,它是等级0的张量(因此只是一个量级),因此无法正确提取特征。

我对TF还是比较陌生,不十分了解为什么是这种情况,也不知道这个张量代表什么。

next(iter(dataset))[0]是否批量调用回调函数而不是基础示例?我曾尝试删除map,但是文档说默认行为是生成大小为1的批次,但这不一定能解决问题。

任何帮助将不胜感激!

0 个答案:

没有答案