我有一个尺寸为(Nx128x3)的3D张量参数,其中N是批量大小。
我还有一个索引张量索引=(Nx16),它将指标保存在A的第二维中。我想获取给定索引的整行,以便结果为(Nx16x3)。
目前我正在使用以下代码
gathered = tf.reshape(tf.gather(tf.reshape(params,-1,int(params.shape[2])]),tf.reshape(indices,[-1,])),[-1,int(indices.shape[1]),int(params.shape[2])])
有没有办法用gather_nd写这个?
- 我目前的完整代码:
import tensorflow as tf
import numpy as np
N = 10
params = tf.constant(np.random.randn(N, 128, 3), dtype=tf.float32)
indices = tf.constant(np.random.randint(0, 128, [N,16]), dtype=tf.int32)
gathered = tf.reshape(tf.gather(tf.reshape(params,[-1,int(params.shape[2])]),tf.reshape(indices,[-1,])),[-1,int(indices.shape[1]),int(params.shape[2])])
with tf.Session() as sess:
result = sess.run(gathered)
print('params: ',params.shape)
print('ind: ',indices.shape)
print('result: ',result.shape)