Race condition in TensorFlow with control_dependencies when computing gradients at the updated point?

Asked: 2017-12-08 02:51:50

Tags: python tensorflow gradients

I am currently facing a tricky problem.

Imagine you have variables p and v and functions f(p) and G(p), and you want to perform two steps:

  1. p = G(p)
  2. v = f'(p), i.e. first update p, then compute the gradient at the newly updated p and save it to v.

One could run sess.run() twice, but can it be done in a single run? Here is my code:

    import tensorflow as tf
    import numpy as np

    sess = tf.InteractiveSession()
    shape = [1000, 1000]
    a = tf.Variable(tf.truncated_normal(shape, stddev=0.01))
    b = tf.Variable(tf.truncated_normal(shape, stddev=0.01))
    c = tf.Variable(tf.truncated_normal(shape, stddev=0.01))
    params = [a, b, c]

    fun = tf.reduce_sum(tf.multiply(tf.multiply(a, c), b))

    # Variables that will hold the gradients computed after the updates
    o = [tf.Variable(tf.zeros(shape)) for i in range(3)]
    sess.run(tf.global_variables_initializer())

    grad = tf.gradients(fun, params)

    # Chain the updates so they run in order: a, then b, then c
    update_a = tf.assign(a, b + c).op
    with tf.control_dependencies([update_a]):
        update_b = tf.assign(b, c + a).op

    with tf.control_dependencies([update_b]):
        update_c = tf.assign(c, a + b).op

    # Only after all updates have run, compute the gradients and store them in o
    with tf.control_dependencies([update_c]):
        df = tf.gradients(fun, params, gate_gradients=True)
        updates = tf.group(*[oi.assign(dfi) for oi, dfi in zip(o, df)])

    def computeTotalSum(v):
        return [np.sum(np.abs(j)) for j in sess.run(v)]

    print(computeTotalSum(params))
    print(computeTotalSum(o))
    print(computeTotalSum(grad))

    sess.run(updates)

    print(computeTotalSum(params))
    print(computeTotalSum(o))
    print(computeTotalSum(grad))
    

After running this, I get the following output:

    [7229.5298, 7229.4917, 7227.6387]
    [0.0, 0.0, 0.0]
    [52.223175, 52.195709, 52.314934]
    [10073.588, 16026.059, 25751.27]
    [619.71185, 52.195709, 234.3277]
    [619.71185, 388.07501, 234.3277]
    

However, I expect the last two lines to be identical!
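For reference, the intended two-step semantics (first update the parameters in order, then differentiate at the updated point, as in the two-sess.run() approach mentioned above) can be sketched in plain NumPy with the gradients of fun = sum(a*b*c) written out analytically. The small shape, seed, and finite-difference check are illustrative assumptions, not part of the original post:

```python
import numpy as np

np.random.seed(0)
shape = (4, 4)  # small shape for illustration
a = np.random.randn(*shape) * 0.01
b = np.random.randn(*shape) * 0.01
c = np.random.randn(*shape) * 0.01

# Step 1: apply the updates strictly in order, each seeing the previous result
a = b + c
b = c + a
c = a + b

# Step 2: gradients of fun = sum(a*b*c) at the *updated* point,
# written out analytically: d/da = b*c, d/db = a*c, d/dc = a*b
grad_a = b * c
grad_b = a * c
grad_c = a * b

def fun(a_, b_, c_):
    return np.sum(a_ * b_ * c_)

# Finite-difference sanity check on one entry of a
eps = 1e-6
a_pert = a.copy()
a_pert[0, 0] += eps
fd = (fun(a_pert, b, c) - fun(a, b, c)) / eps
assert abs(fd - grad_a[0, 0]) < 1e-4
```

When the graph executes both the assignments and the gradient ops in one run, nothing forces the gradient's reads of a, b, and c to happen after the assignments unless every read op is created under the control dependency, which is where the race comes from.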

0 Answers:

No answers yet.