如何进行非累积张量流scatter_add?

时间:2017-08-16 10:57:38

标签: tensorflow

作为学习tensorflow的方法的一部分,我正在转换一些现有的矩阵处理逻辑。其中一个步骤是分散操作,例如下面示例中使用scatter_add的分散操作。我对这个例子的问题在于,每次评估操作时,它都会累积在前一个结果之上。使用如下所示的3次run()调用,打印的结果为:

[[8 12 8]...]
[[16 24 16]...]
[[24 36 24]...]

而我想要的每次都是[[8 12 8]...]indices向量包含重复项,updates中的相应元素需要加在一起,而不是scattered中已存在的现有值。

tensorflow文档中的分散操作似乎都不是我想要的。是否有适当的操作使用?如果没有,达到我需要的最佳方式是什么?

import tensorflow as tf

indices = tf.constant([0, 1, 0, 1, 0, 1, 0, 1], tf.int32)

updates = tf.constant([
            [1., 2., 3., 4.],
            [2., 3., 4., 1.],
            [3., 4., 1., 2.],
            [4., 1., 2., 3.],
            [1., 2., 3., 4.],
            [2., 3., 4., 1.],
            [3., 4., 1., 2.],
            [4., 1., 2., 3.]], tf.float32)

scattered = tf.Variable([
            [0., 0., 0., 0.,],
            [0., 0., 0., 0.,]], tf.float32)

# Requirement:
# scattered[i, j] = Sum of updates[k, j] where indices[k] == i
#
# i.e.
#   scattered_data = [
#     [1+3+1+3, 2+4+2+4, 3+1+3+1, 4+2+4+2], 
#     [2+4+2+4, 3+1+3+1, 4+2+4+2, 1+3+1+3]]
#   == [
#     [ 8, 12,  8, 12],
#     [12,  8, 12,  8]]

scattered = tf.scatter_add(scattered, indices, updates, use_locking=True, name='scattered')
scattered_print = tf.Print(scattered, [scattered])

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(scattered_print)
# Printout: [[8 12 8]...]
sess.run(scattered_print)
# Printout: [[16 24 16]...]
sess.run(scattered_print)
# Printout: [[24 36 24]...]
sess.close()

2 个答案:

答案 0 :(得分:0)

Scatter_add更新变量引用。所以你可以做一些如下所示的事情:

tf.matmul(tf.cast(tf.concat([indices[tf.newaxis,...], 1-indices[tf.newaxis,...]], axis=0), tf.float32),updates)

答案 1 :(得分:0)

scatter_add调用的以下修改似乎可以使事情按预期工作:

with tf.control_dependencies([scattered.initializer]):
    scattered = tf.scatter_add(scattered, indices, updates, use_locking=True, name='scattered')

基本原理是,由于我使用零初始化变量,因此强制初始化程序在每次scatter_add操作之前重新运行将清除它并避免累积更新。

这对我来说似乎有点笨拙 - 我原本预计会有一个单一呼叫解决方案。而且我不确定有多少不必要的内存分配和释放可能会发生,但我认为这是后来要考虑的事情。