使用py_func时,数据集API不会为其输出张量传递维度信息

时间:2018-02-16 09:40:28

标签: python tensorflow tensorflow-datasets

要重现我的问题,请先尝试(使用py_func进行映射):

import tensorflow as tf
import numpy as np
def image_parser(image_name):
    a = np.array([1.0,2.0,3.0], dtype=np.float32)
    return a

images = [[1,2,3],[4,5,6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(lambda image:tuple(tf.py_func(image_parser, [image], [tf.float32])), num_parallel_calls = 2)
im_dataset = im_dataset.prefetch(4)
iterator = im_dataset.make_initializable_iterator()
print(im_dataset.output_shapes)

它会给你(TensorShape(None),)

但是,如果你尝试这个(使用直接tensorflow映射而不是py_func):

import tensorflow as tf
import numpy as np

def image_parser(image_name)
    return image_name

images = [[1,2,3],[4,5,6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(image_parser)
im_dataset = im_dataset.prefetch(4)
iterator = im_dataset.make_initializable_iterator()
print(im_dataset.output_shapes)

它将为您提供精确的张量维度(3,)

1 个答案:

答案 0 :(得分:3)

这是tf.py_func的一般问题,因为TensorFlow无法推断输出形状本身,例如参见this answer

如果需要,可以通过移动解析函数中的tf.py_func来自己设置形状:

def parser(x):
    a = np.array([1.0,2.0,3.0])
    y = tf.py_func(lambda: a, [], tf.float32)
    y.set_shape((3,))
    return y

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(parser)
print(dataset.output_shapes)  # will correctly print (3,)