我正在尝试使用二维张量来索引Tensorflow中的三维张量。例如,我的x
形状为[2, 3, 4]
:
[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]]
我想用另一个y
的张量[2, 3]
对其进行索引,其中y
的每个元素都索引x
的最后一个维度。例如,如果我们有y
喜欢:
[[0, 2, 3],
[1, 0, 2]]
输出应为[2, 3]
形状:
[[0, 6, 11],
[13, 16, 22]]
答案 0 :(得分:1)
使用tf.meshgrid
创建索引,然后使用tf.gather_nd
提取元素:
# create a list of indices for except the last axis
idx_except_last = tf.meshgrid(*[tf.range(s) for s in x.shape[:-1]], indexing='ij')
# concatenate with last axis indices
idx = tf.stack(idx_except_last + [y], axis=-1)
# gather elements based on the indices
tf.gather_nd(x, idx).eval()
# array([[ 0, 6, 11],
# [13, 16, 22]])