tf.scatter_add导致循环错误

时间:2017-10-25 14:34:40

标签: python tensorflow

我发现tf.scatter_add的行为非常奇怪:我创建了一个tf.while_loop,它创建了一个包含在tf.Variable中的Tensor。

如果我不在循环外向变量添加内容,则tensorflow会导致错误,告诉我变量不可变。

这是一个MWE:

import tensorflow as tf        

m = 25
batch_num = 32
num_bus = 50

C = tf.zeros((m, batch_num, num_bus, m),tf.float64)
C = tf.Variable(C)

c = tf.ones((batch_num, num_bus, m), tf.float64)
#C = tf.scatter_add(C,0,c)

k = tf.constant(1)

stop_cond = lambda k,C: k<m

def construct_C(k, C):
    upd_c = c+1
    C = tf.scatter_add(C,k,upd_c)
    return k+1,C

k,C = tf.while_loop(stop_cond,construct_C, (k,C))

sess = tf.Session()
sess.run(tf.global_variables_initializer())
C1 = sess.run(C)

此代码导致错误:TypeError: 'ScatterAdd' Op requires that input 'ref' be a mutable tensor (e.g.: a tf.Variable)。但是,当我取消注释C = tf.scatter_add(C,0,c)时,一切正常。

这是打算吗?我做错了什么?

1 个答案:

答案 0 :(得分:1)

听起来像某些while_loop原语并不了解变量(相反,他们知道有关参数的Tensors)。这看起来像代码中的错误 - 请在github上提交一个问题。