基于索引张量从张量中选择值

时间:2021-06-12 15:49:35

标签: tensorflow tensorflow2.0

我有两个矩阵。矩阵 A 包含一些值,矩阵 B 包含索引。矩阵 A 和 B 的形状分别为 (batch, values) 和 (batch, index)。

我的目标是根据矩阵 B 沿批次维度的索引从矩阵 A 中选择值。

例如:

# Matrix A
<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[0., 1., 2., 3., 4.],
       [5., 6., 7., 8., 9.]], dtype=float32)>

# Matrix B
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[0, 1],
       [1, 2]], dtype=int32)>

# Expected Result
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 1.],
       [6., 7.]], dtype=int32)>

如何在 Tensorflow 中实现这一点?

非常感谢!

1 个答案:

答案 0 :(得分:1)

您可以使用 tf.gather 函数实现这一点。

mat_a = tf.constant([[0., 1., 2., 3., 4.],
                     [5., 6., 7., 8., 9.]])
mat_b = tf.constant([[0, 1], [1, 2]])

out = tf.gather(mat_a, mat_b, batch_dims=1)
out.numpy()
array([[0., 1.],
       [6., 7.]], dtype=float32)