我想将tf.gather()
应用于给定参数张量和索引张量的所有行。
我可以在两个1D张量上应用tf.gather()
来提取1D张量:
# params == array([3, 8, 9, 7, 6])
# inds == array([1, 2, 3])
>>> tf.gather(params, inds).eval()
array([8, 9, 7])
现在如果我有两个2D张量,并希望逐行应用tf.gather()
怎么办?我想要这样的东西:
# params == array([[3, 8, 9, 7, 6],
# [6, 1, 7, 0, 7],
# [7, 4, 4, 5, 8]])
# inds == array([[1, 2, 3],
# [2, 3, 4],
# [0, 1, 2]])
>>> row_wise_gather(params, inds)
array([[8, 9, 7],
[7, 0, 7],
[7, 4, 4]]
我到目前为止最接近的是tf.gather()
使用axis=1
,产生3D张量,然后使用gather_nd()
索引结果:
>>> gathered3d = tf.gather(params, inds, axis=1)
# gathered3d == array([[[8, 9, 7],
# [9, 7, 6],
# [3, 8, 9]],
#
# [[1, 7, 0],
# [7, 0, 7],
# [6, 1, 7]],
#
# [[4, 4, 5],
# [4, 5, 8],
# [7, 4, 4]]])
>>> tf.gather_nd(gathered3d, [[0, 0], [1, 1], [2, 2]]).eval()
array([[8, 9, 7],
[7, 0, 7],
[7, 4, 4]])
(我会调用其他功能而不是提供字面值,但这不是重点,而不是问题)
这非常笨拙。有没有更有效的方法来做到这一点?
顺便说一句,我使用的索引总是逐个增加的值;每行只有不同的开始和结束值。这可能会使问题更容易。