Tensorflow AssertionError“渐变列表应该已经汇总了”

时间:2018-02-02 13:27:29

标签: tensorflow

我的函数(new Query()) ->select('*') ->from('stock_in') ->where("user_id = '$request_user_id'") ->andWhere("product_id = '$product_id'") ->andWhere("remaining_quantity > '0'") ->one(); 在内部使用了一些ftf.while_loop来计算值tf.gradients。像这样的东西

y = f(x)

有几百行代码。但我相信那些是重点......

现在当我评估def f( x ): ... def body( g, x ): # Compute the gradient here grad = tf.gradients( g, x )[0] ... return ... return tf.while_loop( cond, body, parallel_iterations=1 ) 时,我得到了我期望的值......

f(x)

但是,当我尝试评估y = known output of f(x) with tf.Session() as sess: fx = f(x) print("Error = ", y - sess.run(fx, feed_dict)) # Prints 0 相对于f(x)的渐变时,即

x

我收到错误

grads = tf.gradients( fx, x )[0]

以下是完整的跟踪:

AssertionError: gradients list should have been aggregated by now.

有人可以概述这个错误的可能原因吗?我不知道从哪里开始寻找这个问题......

一些观察结果:

  1. 请注意,我已将while循环的并行迭代设置为1.这 应该意味着由于从多个线程读取和写入而没有错误。

  2. 如果我放弃while循环,只让 File "C:/Dropbox/bob/tester.py", line 174, in <module> grads = tf.gradients(y, x)[0] File "C:\Anaconda36\lib\site-packages\tensorflow\python\ops\gradients_impl.py", line 649, in gradients return [_GetGrad(grads, x) for x in xs] File "C:\Anaconda36\lib\site-packages\tensorflow\python\ops\gradients_impl.py", line 649, in <listcomp> return [_GetGrad(grads, x) for x in xs] File "C:\Anaconda36\lib\site-packages\tensorflow\python\ops\gradients_impl.py", line 727, in _GetGrad "gradients list should have been aggregated by now.") AssertionError: gradients list should have been aggregated by now. 返回f,则代码会运行:

    body()
  3. 显然,输出不正确,但至少计算了渐变。

1 个答案:

答案 0 :(得分:0)

我遇到了类似的问题。我注意到的一些模式:

  • 如果x中使用的tf.gradientsbody中需要维度广播的方式使用,则会出现此错误。如果我将其更改为不需要广播的广告,则tf.gradients会返回[None]。我没有对此进行广泛测试,因此这种模式在所有示例中可能都不一致。
  • 这两种情况(返回[None]并提出此断言错误)可以通过区分tf.identity(y)而不仅仅是y来解决:grads = tf.gradients(tf.identity(y), xs)我完全不知道为什么会这样做