要重现我的问题,请先尝试(使用py_func进行映射):
import tensorflow as tf
import numpy as np
def image_parser(image_name):
a = np.array([1.0,2.0,3.0], dtype=np.float32)
return a
images = [[1,2,3],[4,5,6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(lambda image:tuple(tf.py_func(image_parser, [image], [tf.float32])), num_parallel_calls = 2)
im_dataset = im_dataset.prefetch(4)
iterator = im_dataset.make_initializable_iterator()
print(im_dataset.output_shapes)
它会给你(TensorShape(None),)
但是,如果你尝试这个(使用直接tensorflow映射而不是py_func):
import tensorflow as tf
import numpy as np
def image_parser(image_name)
return image_name
images = [[1,2,3],[4,5,6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(image_parser)
im_dataset = im_dataset.prefetch(4)
iterator = im_dataset.make_initializable_iterator()
print(im_dataset.output_shapes)
它将为您提供精确的张量维度(3,)
答案 0 :(得分:3)
这是tf.py_func
的一般问题,因为TensorFlow无法推断输出形状本身,例如参见this answer。
如果需要,可以通过移动解析函数中的tf.py_func
来自己设置形状:
def parser(x):
a = np.array([1.0,2.0,3.0])
y = tf.py_func(lambda: a, [], tf.float32)
y.set_shape((3,))
return y
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(parser)
print(dataset.output_shapes) # will correctly print (3,)