scatter_nd_add()示例,用于TensroFlow中的稀疏加法

时间:2017-10-25 19:09:10

标签: python tensorflow sparse-matrix

我无法将tf.scatter_nd_add()应用于2D张量。文档有点不清楚,并不包含稀疏更新的示例,但仅适用于完整切片更新。

我的情况如下:

  • updates - 2D张量形状[None, 6]
  • indices - 2D张量形状[None, 6]
  • ref - 2D形状为[None, 6]
  • 的零变量

保证updatesindicesref的第一个维度始终相等,但该维度的大小可能会有所不同。我想要执行的更新看起来像

for i, j:
    k = indices[i][j]
    ref[i][k] += updates[i][j] 

请注意indices包含重复项。 tf.scatter_nd_add(ref, indices, updates)抱怨形状不匹配,我无法弄清楚如何重新构建张量以执行更新。

1 个答案:

答案 0 :(得分:1)

我明白了。 indices中的每个2D条目必须实际指定将在ref中更新的绝对位置。这意味着indices必须是3D,然后非矢量化更新如下所示:

for i, j:
    r, k = indices[i][j]
    ref[r][k] += updates[i][j]

在上述问题中,r总是等于i

这是一个具有不同形状的完整Tensorflow实现。为清楚起见,在以下示例中,col_indices对应于原始问题中的indices

import tensorflow as tf
import numpy as np

updates     = tf.placeholder(dtype=tf.float32,  shape=[None, 6])
col_indices = tf.placeholder(dtype=tf.int32,    shape=[None, 6])
row_indices = tf.cumsum(tf.ones_like(col_indices), axis=0, exclusive=True)
indices     = tf.concat([tf.expand_dims(row_indices, axis=-1), 
                         tf.expand_dims(col_indices, axis=-1)], axis=-1)

tmp_var     = tf.Variable(0, trainable=False, dtype=tf.float32, validate_shape=False)
ref         = tf.assign(tmp_var, tf.zeros_like(updates), validate_shape=False)
# This makes sure that ref is always 0 before scatter_nd_add() runs
with tf.control_dependencies([target_var]):
  result = tf.scatter_nd_add(ref, indices, updates)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create example input data 
np_input = np.arange(0, 6, 1, dtype=np.int32)
np_input = np.tile(np_input[None,:], [10, 1])

res = sess.run(result, feed_dict={updates: np_input, col_indices: np_input})
print(res)