tf.data.dataset:如何将形状分配给保证输出一定形状的数据集(形状未定义)?

时间:2019-11-16 12:18:52

标签: python tensorflow tensorflow-datasets

我有一个tf2数据集API dataset,该API经历了多个map操作,随后是tf.image.resize,该操作会不断输出形状(300, 300),即,确保每个记录在之后都具有此形状所有地图操作。但是,这不是固有地推断的,因此Tensor Spec显示<undefined>, <undefined>的形状。如果未定义的形状数据集传递给具有预定义输入形状的模型,则会引发错误。

一些搜索帮助我找到了此功能tf.contrib.data.assert_element_shapeIssue #16052

dataset = dataset.apply(tf.data.experimental.assert_element_shape(custom_shape))

但是此功能已在tf2中删除,并且文档不建议使用其他内容代替assert_element_shape。 什么是等效的?或如何为保证输出特定形状的数据集分配形状?

1 个答案:

答案 0 :(得分:0)

由于某些原因,在我添加了set_shape的地图函数中添加tf.image.resize无效。

# does not work
function my_map_function(image, label):
    # some image operations here
    image = tf.image.resize(image, size=[300, 300])
    image.set_shape((300, 300, 3))
    return image, label

但是当我创建一个单独的map函数时,它可以工作:

# works
def set_shapes(image, label):
    image.set_shape((300, 300, 3))
    label.set_shape([])
    return image, label

也许我会坚持下去,直到直接将assert_element_shapeset_element_shape添加为单独的函数