在张量流中索引3D批量张量

时间:2018-06-02 10:46:21

标签: python tensorflow indexing slice

我有一个尺寸为(Nx128x3)的3D张量参数,其中N是批量大小。

我还有一个索引张量索引=(Nx16),它将指标保存在A的第二维中。我想获取给定索引的整行,以便结果为(Nx16x3)。

目前我正在使用以下代码

gathered = tf.reshape(tf.gather(tf.reshape(params,-1,int(params.shape[2])]),tf.reshape(indices,[-1,])),[-1,int(indices.shape[1]),int(params.shape[2])])

有没有办法用gather_nd写这个?

- 我目前的完整代码:

import tensorflow as tf
import numpy as np

N = 10
params = tf.constant(np.random.randn(N, 128, 3), dtype=tf.float32)
indices = tf.constant(np.random.randint(0, 128, [N,16]), dtype=tf.int32)

gathered = tf.reshape(tf.gather(tf.reshape(params,[-1,int(params.shape[2])]),tf.reshape(indices,[-1,])),[-1,int(indices.shape[1]),int(params.shape[2])])

with tf.Session() as sess:
    result = sess.run(gathered)
print('params: ',params.shape)
print('ind: ',indices.shape)
print('result: ',result.shape)

0 个答案:

没有答案