使用tf.gather基于行数的另一个张量按行提取张量(第一维)

时间:2019-03-01 16:17:02

标签: python tensorflow tensor

我有两个张量,尺寸分别为A:[B,3000,3]C:[B,4000]。我想使用tf.gather()来使用张量C中的每一行作为索引,并使用张量A中的每一行作为参数,以获得大小为[B,4000,3]的结果。

下面是一个使它更易于理解的示例:假设我有张量为

  

A = [[1,2,3],[4,5,6],[7,8,9]],

     

C = [0,2,1,2,1],

     

result = [[1,2,3],[7,8,9],[4,5,6],[7,8,9],[4,5,6]],

通过使用tf.gather(A,C)。应用于尺寸小于3的张量时,一切都很好。

但是在以描述为开头的情况下,通过应用tf.gather(A,C,axis=1),结果张量的形状为

  

[B,B,4000,3]

似乎tf.gather()只是对张量C中的每个元素做了工作,作为索引来收集张量A中的元素。我正在考虑的唯一解决方案是使用for循环,但是通过使用tf.gather(A[i,...],C[i,...])获得张量的正确大小会极大地降低计算能力

  

[B,4000,3]

因此,是否有任何功能可以类似地完成此任务?

1 个答案:

答案 0 :(得分:0)

您需要使用tf.gather_nd

Hello
World