我发现tf.scatter_add的行为非常奇怪:我创建了一个tf.while_loop,它创建了一个包含在tf.Variable中的Tensor。
如果我不在循环外向变量添加内容,则tensorflow会导致错误,告诉我变量不可变。
这是一个MWE:
import tensorflow as tf
m = 25
batch_num = 32
num_bus = 50
C = tf.zeros((m, batch_num, num_bus, m),tf.float64)
C = tf.Variable(C)
c = tf.ones((batch_num, num_bus, m), tf.float64)
#C = tf.scatter_add(C,0,c)
k = tf.constant(1)
stop_cond = lambda k,C: k<m
def construct_C(k, C):
upd_c = c+1
C = tf.scatter_add(C,k,upd_c)
return k+1,C
k,C = tf.while_loop(stop_cond,construct_C, (k,C))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
C1 = sess.run(C)
此代码导致错误:TypeError: 'ScatterAdd' Op requires that input 'ref' be a mutable tensor (e.g.: a tf.Variable)
。但是,当我取消注释C = tf.scatter_add(C,0,c)
时,一切正常。
这是打算吗?我做错了什么?
答案 0 :(得分:1)
听起来像某些while_loop原语并不了解变量(相反,他们知道有关参数的Tensors)。这看起来像代码中的错误 - 请在github上提交一个问题。