当索引是矩阵时,tf.scatter_add和tf.scatter_nd有什么区别?

时间:2019-06-07 09:27:36

标签: tensorflow

tf.scatter_addtf.scatter_nd都允许indices为矩阵。从tf.scatter_nd的文档中很明显,indices的最后一个维度包含用于索引形状为shape的张量的值。 indices的其他维度定义了要散布的元素/切片的数量。假设updates的排名为Nk的前indices个维度(最后一个维度除外)应与k的前updates个维度匹配。 (N-k)的最后updates尺寸应与(N-k)的最后shape尺寸匹配。

这意味着tf.scatter_nd可用于执行N维散射。但是,tf.scatter_add也将矩阵作为indices。但是,尚不清楚indices的哪个维度对应于要执行的分散数量,以及这些维度如何与updates对齐。有人可以通过示例提供清晰的解释吗?

1 个答案:

答案 0 :(得分:1)

@shaunshd,我终于完全理解tf.scatter_nd _ *()参数中的3个张量关系,尤其是当索引具有多维时。例如: 索引= tf.constant([[[0,0,0],[1,1,1],[2,2,2],[3,3,3],[3,3,2]],dtype = tf.int32)

请不要指望tf.rank(indices)> 2,tf.rank(indices)== 2永远为真;

以下是我的测试代码,以显示比tensroflow官方网站中提供的示例更复杂的测试用例:

def testScatterNDUpdate(self):
    ref = tf.Variable(np.zeros(shape=[4, 4, 4], dtype=np.float32))
    indices = tf.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [3,3,2]], dtype=tf.int32)
    updates = tf.constant([1,2,3,4,5], dtype=tf.float32)
    #shape = (4,4,4)
    print(tf.tensor_scatter_nd_update(ref, indices, updates))
    print(ref.scatter_nd_update(indices, updates))
    #print(updates.shape[-1]==shape[-1], updates.shape[0]<=shape[0])
    #conditions are:
    #      updates.shape[0]==indices[0]
    #      indices[1]<=len(shape)
    #      tf.rank(indices)==2

您还可以使用以下伪代码理解索引:

def scatter_nd_update(ref, indices, updates):
    for i in range(tf.shape(indices)[0]):
        ref[indices[i]]=updates[i]
    return ref

与numpy的花式索引功能相比,tensorflow的索引功能仍然非常难以使用,并且具有不同的使用样式,尚未与numpy统一。希望在tf3.x中情况会更好