我有两个矩阵。矩阵 A 包含一些值,矩阵 B 包含索引。矩阵 A 和 B 的形状分别为 (batch, values) 和 (batch, index)。
我的目标是根据矩阵 B 沿批次维度的索引从矩阵 A 中选择值。
例如:
# Matrix A
<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]], dtype=float32)>
# Matrix B
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[0, 1],
[1, 2]], dtype=int32)>
# Expected Result
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 1.],
[6., 7.]], dtype=int32)>
如何在 Tensorflow 中实现这一点?
非常感谢!
答案 0 :(得分:1)
您可以使用 tf.gather
函数实现这一点。
mat_a = tf.constant([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]])
mat_b = tf.constant([[0, 1], [1, 2]])
out = tf.gather(mat_a, mat_b, batch_dims=1)
out.numpy()
array([[0., 1.],
[6., 7.]], dtype=float32)