将tf.gather应用于两个张量的所有行

时间:2018-01-10 06:02:34

标签: python tensorflow

我想将tf.gather()应用于给定参数张量和索引张量的所有行。

我可以在两个1D张量上应用tf.gather()来提取1D张量:

# params == array([3, 8, 9, 7, 6])
# inds == array([1, 2, 3])

>>> tf.gather(params, inds).eval()
array([8, 9, 7])

现在如果我有两个2D张量,并希望逐行应用tf.gather()怎么办?我想要这样的东西:

# params == array([[3, 8, 9, 7, 6],
#                  [6, 1, 7, 0, 7],
#                  [7, 4, 4, 5, 8]])

# inds == array([[1, 2, 3],
#                [2, 3, 4],
#                [0, 1, 2]])

>>> row_wise_gather(params, inds)
array([[8, 9, 7],
       [7, 0, 7],
       [7, 4, 4]]

我到目前为止最接近的是tf.gather()使用axis=1,产生3D张量,然后使用gather_nd()索引结果:

>>> gathered3d = tf.gather(params, inds, axis=1)
# gathered3d == array([[[8, 9, 7],
#                       [9, 7, 6],
#                       [3, 8, 9]],
#
#                      [[1, 7, 0],
#                       [7, 0, 7],
#                       [6, 1, 7]],
#
#                      [[4, 4, 5],
#                       [4, 5, 8],
#                       [7, 4, 4]]])

>>> tf.gather_nd(gathered3d, [[0, 0], [1, 1], [2, 2]]).eval()
array([[8, 9, 7],
       [7, 0, 7],
       [7, 4, 4]])

(我会调用其他功能而不是提供字面值,但这不是重点,而不是问题)

这非常笨拙。有没有更有效的方法来做到这一点?

顺便说一句,我使用的索引总是逐个增加的值;每行只有不同的开始和结束值。这可能会使问题更容易。

0 个答案:

没有答案