无法使用tf.gradients获取tf.gather_nd op的渐变吗?

时间:2018-12-07 07:31:53

标签: python tensorflow

我尝试获取tf.gather_nd op的渐变,但输出显示渐变为None。 Tensorflow已在https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_grad.py#L355中实现了“ def _GatherNdGrad(op,grad):”。但是,它不起作用。

params = tf.constant(np.array([[1,2], [3,4]]))
indices = tf.constant(np.array([[[0,0],[0,1]],[[1,1],[0,1]]]))
w = tf.constant(np.random.randn(4,3), dtype=tf.float32)
with tf.Session() as sess:
    out_1 = tf.gather_nd(params=params, indices=indices)
    print(sess.run(out_1))
    gradients = tf.gradients(ys=out_1, xs=indices)
    print(gradients)
    print(sess.run(gradients)
  

[[1 2]    [4 2]]

     

[无]

0 个答案:

没有答案