I am trying to get the gradient of the tf.gather_nd op, but the output shows the gradient as None. TensorFlow already implements `def _GatherNdGrad(op, grad):` in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/array_grad.py#L355. However, it does not work.
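For intuition, here is a minimal pure-Python sketch of what that registered gradient computes for the example below. The helpers `gather_nd_2d` and `gather_nd_grad_2d` are hypothetical (not TensorFlow API), and the sketch only covers a 2-D `params` where each index is a full (row, column) pair. In `array_grad.py`, `_GatherNdGrad` scatters the incoming gradient back into the shape of `params` (via `scatter_nd`) and returns `None` as the gradient for `indices`, because integer indices only select positions and are not differentiable.

```python
# Hypothetical pure-Python helpers (not TensorFlow API) that mirror
# tf.gather_nd and the rule in _GatherNdGrad for a 2-D params tensor
# where every index is a full (row, col) pair.

def gather_nd_2d(params, indices):
    """Return out[i][j] = params[r][c] where (r, c) = indices[i][j]."""
    return [[params[r][c] for r, c in row] for row in indices]


def gather_nd_grad_2d(params_shape, indices, upstream):
    """Gradient w.r.t. params: scatter-add upstream grads to gathered positions.

    Positions gathered more than once accumulate their gradients; positions
    never gathered get 0. Integer indices receive no gradient (None).
    """
    rows, cols = params_shape
    grad_params = [[0.0] * cols for _ in range(rows)]
    for idx_row, g_row in zip(indices, upstream):
        for (r, c), g in zip(idx_row, g_row):
            grad_params[r][c] += g
    return grad_params, None


params = [[1.0, 2.0], [3.0, 4.0]]
indices = [[[0, 0], [0, 1]], [[1, 1], [0, 1]]]
out = gather_nd_2d(params, indices)

# tf.gradients seeds the backward pass with ones shaped like ys.
ones = [[1.0] * len(row) for row in out]
grad_params, grad_indices = gather_nd_grad_2d((2, 2), indices, ones)
print(out)           # [[1.0, 2.0], [4.0, 2.0]]
print(grad_params)   # [[1.0, 2.0], [0.0, 1.0]]
print(grad_indices)  # None
```

Position `[0, 1]` is gathered twice, so its gradient is 2; `[1, 0]` is never gathered, so its gradient is 0.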
```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)

params = tf.constant(np.array([[1, 2], [3, 4]]))
indices = tf.constant(np.array([[[0, 0], [0, 1]], [[1, 1], [0, 1]]]))
w = tf.constant(np.random.randn(4, 3), dtype=tf.float32)  # not used below
with tf.Session() as sess:
    out_1 = tf.gather_nd(params=params, indices=indices)
    print(sess.run(out_1))
    gradients = tf.gradients(ys=out_1, xs=indices)
    print(gradients)
    print(sess.run(gradients))
```
Output:

```
[[1 2]
 [4 2]]
[None]
```
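For comparison, a minimal sketch of the same graph that asks for the gradient with respect to `params` instead. It assumes TF 1.x-style graph execution through `tf.compat.v1` (so it should also run on TF 2.x), and it makes `params` `float32`, since gradients are only meaningful for floating-point tensors. `indices` is still passed in `xs` only to show that it receives `None`.

```python
import tensorflow as tf

# Assumption: TF 1.x-style graph mode; tf.compat.v1 keeps this runnable on TF 2.x.
tf1 = tf.compat.v1
tf1.disable_eager_execution()

# Floating-point params so that a gradient can flow back to them.
params = tf.constant([[1.0, 2.0], [3.0, 4.0]], dtype=tf.float32)
indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [0, 1]]])
out_1 = tf.gather_nd(params=params, indices=indices)

# Differentiate w.r.t. params; indices only select positions and get None.
grad_params, grad_indices = tf1.gradients(ys=out_1, xs=[params, indices])

with tf1.Session() as sess:
    out_value, grad_value = sess.run([out_1, grad_params])

print(out_value)
print(grad_value)    # ones scattered into params' shape; [0, 1] counted twice
print(grad_indices)  # None
```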