如何在Tensorflow中用二维张量索引三维张量?

时间:2018-03-12 00:17:47

标签: python tensorflow

我正在尝试使用二维张量来索引Tensorflow中的三维张量。例如,我的x形状为[2, 3, 4]

[[[ 0,  1,  2,  3],
  [ 4,  5,  6,  7],
  [ 8,  9, 10, 11]],

 [[12, 13, 14, 15],
  [16, 17, 18, 19],
  [20, 21, 22, 23]]]

我想用另一个y的张量[2, 3]对其进行索引,其中y的每个元素都索引x的最后一个维度。例如,如果我们有y喜欢:

[[0, 2, 3],
 [1, 0, 2]]

输出应为[2, 3]形状:

[[0, 6, 11],
 [13, 16, 22]]

1 个答案:

答案 0 :(得分:1)

使用tf.meshgrid创建索引,然后使用tf.gather_nd提取元素:

# create a list of indices for except the last axis
idx_except_last = tf.meshgrid(*[tf.range(s) for s in x.shape[:-1]], indexing='ij')

# concatenate with last axis indices
idx = tf.stack(idx_except_last + [y], axis=-1)

# gather elements based on the indices
tf.gather_nd(x, idx).eval()

# array([[ 0,  6, 11],
#        [13, 16, 22]])