我有一个tf2数据集API dataset
,该API经历了多个map
操作,随后是tf.image.resize
,该操作会不断输出形状(300, 300)
,即,确保每个记录在之后都具有此形状所有地图操作。但是,这不是固有地推断的,因此Tensor Spec显示<undefined>, <undefined>
的形状。如果未定义的形状数据集传递给具有预定义输入形状的模型,则会引发错误。
一些搜索帮助我找到了此功能tf.contrib.data.assert_element_shape和Issue #16052:
dataset = dataset.apply(tf.data.experimental.assert_element_shape(custom_shape))
但是此功能已在tf2中删除,并且文档不建议使用其他内容代替assert_element_shape。 什么是等效的?或如何为保证输出特定形状的数据集分配形状?
答案 0 :(得分:0)
由于某些原因,在我添加了set_shape
的地图函数中添加tf.image.resize
无效。
# does not work
function my_map_function(image, label):
# some image operations here
image = tf.image.resize(image, size=[300, 300])
image.set_shape((300, 300, 3))
return image, label
但是当我创建一个单独的map函数时,它可以工作:
# works
def set_shapes(image, label):
image.set_shape((300, 300, 3))
label.set_shape([])
return image, label
也许我会坚持下去,直到直接将assert_element_shape
或set_element_shape
添加为单独的函数