我无法将tf.scatter_nd_add()
应用于2D张量。文档有点不清楚,并不包含稀疏更新的示例,但仅适用于完整切片更新。
我的情况如下:
updates
- 2D张量形状[None, 6]
indices
- 2D张量形状[None, 6]
ref
- 2D形状为[None, 6]
保证updates
,indices
和ref
的第一个维度始终相等,但该维度的大小可能会有所不同。我想要执行的更新看起来像
for i, j:
k = indices[i][j]
ref[i][k] += updates[i][j]
请注意indices
包含重复项。 tf.scatter_nd_add(ref, indices, updates)
抱怨形状不匹配,我无法弄清楚如何重新构建张量以执行更新。
答案 0 :(得分:1)
我明白了。 indices
中的每个2D条目必须实际指定将在ref
中更新的绝对位置。这意味着indices
必须是3D,然后非矢量化更新如下所示:
for i, j:
r, k = indices[i][j]
ref[r][k] += updates[i][j]
在上述问题中,r
总是等于i
。
这是一个具有不同形状的完整Tensorflow实现。为清楚起见,在以下示例中,col_indices
对应于原始问题中的indices
:
import tensorflow as tf
import numpy as np
updates = tf.placeholder(dtype=tf.float32, shape=[None, 6])
col_indices = tf.placeholder(dtype=tf.int32, shape=[None, 6])
row_indices = tf.cumsum(tf.ones_like(col_indices), axis=0, exclusive=True)
indices = tf.concat([tf.expand_dims(row_indices, axis=-1),
tf.expand_dims(col_indices, axis=-1)], axis=-1)
tmp_var = tf.Variable(0, trainable=False, dtype=tf.float32, validate_shape=False)
ref = tf.assign(tmp_var, tf.zeros_like(updates), validate_shape=False)
# This makes sure that ref is always 0 before scatter_nd_add() runs
with tf.control_dependencies([target_var]):
result = tf.scatter_nd_add(ref, indices, updates)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Create example input data
np_input = np.arange(0, 6, 1, dtype=np.int32)
np_input = np.tile(np_input[None,:], [10, 1])
res = sess.run(result, feed_dict={updates: np_input, col_indices: np_input})
print(res)