tf.scatter_add和tf.scatter_nd都允许indices
为矩阵。从tf.scatter_nd的文档中很明显,indices
的最后一个维度包含用于索引形状为shape
的张量的值。 indices
的其他维度定义了要散布的元素/切片的数量。假设updates
的排名为N
。 k
的前indices
个维度(最后一个维度除外)应与k
的前updates
个维度匹配。 (N-k)
的最后updates
尺寸应与(N-k)
的最后shape
尺寸匹配。
这意味着tf.scatter_nd
可用于执行N
维散射。但是,tf.scatter_add
也将矩阵作为indices
。但是,尚不清楚indices
的哪个维度对应于要执行的分散数量,以及这些维度如何与updates
对齐。有人可以通过示例提供清晰的解释吗?
答案 0 :(得分:1)
@shaunshd,我终于完全理解tf.scatter_nd _ *()参数中的3个张量关系,尤其是当索引具有多维时。例如: 索引= tf.constant([[[0,0,0],[1,1,1],[2,2,2],[3,3,3],[3,3,2]],dtype = tf.int32)
请不要指望tf.rank(indices)> 2,tf.rank(indices)== 2永远为真;
以下是我的测试代码,以显示比tensroflow官方网站中提供的示例更复杂的测试用例:
def testScatterNDUpdate(self):
ref = tf.Variable(np.zeros(shape=[4, 4, 4], dtype=np.float32))
indices = tf.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [3,3,2]], dtype=tf.int32)
updates = tf.constant([1,2,3,4,5], dtype=tf.float32)
#shape = (4,4,4)
print(tf.tensor_scatter_nd_update(ref, indices, updates))
print(ref.scatter_nd_update(indices, updates))
#print(updates.shape[-1]==shape[-1], updates.shape[0]<=shape[0])
#conditions are:
# updates.shape[0]==indices[0]
# indices[1]<=len(shape)
# tf.rank(indices)==2
您还可以使用以下伪代码理解索引:
def scatter_nd_update(ref, indices, updates):
for i in range(tf.shape(indices)[0]):
ref[indices[i]]=updates[i]
return ref
与numpy的花式索引功能相比,tensorflow的索引功能仍然非常难以使用,并且具有不同的使用样式,尚未与numpy统一。希望在tf3.x中情况会更好