如何使用tf.scatter_add增加张量流中的矩阵元素?

时间:2016-07-19 14:21:38

标签: python tensorflow

tf.scatter_add适用于1d(形状1)张量:

> S = tf.Variable(tf.constant([1,2,3,4]))
> sess.run(tf.initialize_all_variables())
> sess.run(tf.scatter_add(S, [0], [10]))

array([11,  2,  3,  4], dtype=int32)

> sess.run(tf.scatter_add(S, [0, 1], [10, 100]))

array([ 21, 102,   3,   4], dtype=int32)

但是我怎么能增加,比如说<0,0]

的元素
M = tf.Variable(tf.constant([[1,2], [3,4]]))

使它[[2,2],[3,4]] 使用tf.scatter_add?

official documentation有点神秘。我尝试了不同的arg值,比如说

> sess.run(tf.scatter_add(M, [[0, 0]], [1]))
*** ValueError: Shapes (1,) and (1, 2, 2) are not compatible

并没有成功。

顺便说一句,就我而言,M非常大并且动态调整大小。 因此,向M添加等于1的元素矩阵的零但不是这种情况。

1 个答案:

答案 0 :(得分:3)

tf.scatter_add更新张量的片段,不能更新单个系数。例如,它可以一次更新矩阵的整行。

此外,updates tf.scatter_add参数的形状取决于其indices参数的形状。当ref参数是形状为(M, N)的矩阵时,则

  • 如果indices是标量i,则updates应为形状为(N)的向量。
  • 如果indices是形状为[i1, i2, .. ik]的向量(k),则updates的形状应为(k, N)

在您的情况下,您只需将[1, 0]添加到M的第一行,如下所示即可获得所需效果:

sess.run(tf.scatter_add(M, 0, [1, 0]))
array([[2, 2],
   [3, 4]], dtype=int32)