我一直能够毫无问题地使用 tf.tensor_scatter_nd_update
写入张量,但我无法弄清楚为什么它不能使用某些特定的张量。
举一个简单的例子,假设我想根据布尔掩码 input=[[0 0 0]]
将 update=[[1 2 3]]
中的某些值设置为 mask=[[1 0 1]]
。
我只会做:
input=tf.tensor_scatter_nd_update(input,tf.where(mask),update)
期望运算结果为 input=[[1 0 3]]
。
相反,我得到了
ValueError: Dimensions [2,2) of input[shape=[1,3]] = [] must match dimensions [1,2) of updates[shape=[1,3]] = [3]: Shapes must be equal rank, but are 0 and 1 for ... with input shapes: [1,3], [?,2], [1,3].
我真的不知道出了什么问题;即使在更复杂的情况下,我也一直能够毫无问题地使用该功能。
答案 0 :(得分:-1)
我想通了。
部分问题确实是 tf.where()
返回了一个 2-D 张量,但这起作用了,因为我还用它来生成 updates
向量:
input=input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.where(something_else))
解决方案是通过以下方式去除额外的维度:
input=input=tf.tensor_scatter_nd_update(input,tf.where(mask),tf.squeeze(tf.where(something_else)))