tensor_scatter_nd_update ValueError:形状必须具有相同的等级,但为 0 和 1

时间:2021-01-11 12:49:11

标签: python tensorflow

我一直能够毫无问题地使用 tf.tensor_scatter_nd_update 写入张量,但我无法弄清楚为什么它不能使用某些特定的张量。

举一个简单的例子,假设我想根据布尔掩码 input=[[0 0 0]]update=[[1 2 3]] 中的某些值设置为 mask=[[1 0 1]]。 我只会做:

input=tf.tensor_scatter_nd_update(input,tf.where(mask),update)

期望运算结果为 input=[[1 0 3]]

相反,我得到了

ValueError: Dimensions [2,2) of input[shape=[1,3]] = [] must match dimensions [1,2) of updates[shape=[1,3]] = [3]: Shapes must be equal rank, but are 0 and 1 for ... with input shapes: [1,3], [?,2], [1,3].

我真的不知道出了什么问题;即使在更复杂的情况下,我也一直能够毫无问题地使用该功能。

1 个答案:

答案 0 :(得分:-1)

我想通了。

部分问题确实是 tf.where() 返回了一个 2-D 张量,但这起作用了,因为我还用它来生成 updates 向量:

input=input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.where(something_else))

解决方案是通过以下方式去除额外的维度:

input=input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.squeeze(tf.where(something_else)))