我有一个张量" a"形状[72,3,65]和指数" b"来自tf.top_k
的形状[72,3]。我想做的是:
c = tf.gather_nd(a, b)
c
有形状[72],但我希望它有形状[72,3,65]。我该怎么做?
编辑:
我认为我找到了一种使用tf.gather来实现这一目标的方法,但是我怎样才能使用gather_nd来简化事情,或者我是否误解了它的目的?
a = tf.reshape(a, [dim_1_size*dim_2_size, -1])
b = tf.reshape(b, [-1])
c = tf.gather(a, b)
c = tf.reshape(c, old_a_shape)