我有两个张量,尺寸分别为A:[B,3000,3]
和C:[B,4000]
。我想使用tf.gather()
来使用张量C中的每一行作为索引,并使用张量A中的每一行作为参数,以获得大小为[B,4000,3]
的结果。
下面是一个使它更易于理解的示例:假设我有张量为
A = [[1,2,3],[4,5,6],[7,8,9]],
C = [0,2,1,2,1],
result = [[1,2,3],[7,8,9],[4,5,6],[7,8,9],[4,5,6]],
通过使用tf.gather(A,C)
。应用于尺寸小于3的张量时,一切都很好。
但是在以描述为开头的情况下,通过应用tf.gather(A,C,axis=1)
,结果张量的形状为
[B,B,4000,3]
似乎tf.gather()
只是对张量C中的每个元素做了工作,作为索引来收集张量A中的元素。我正在考虑的唯一解决方案是使用for
循环,但是通过使用tf.gather(A[i,...],C[i,...])
获得张量的正确大小会极大地降低计算能力
[B,4000,3]
因此,是否有任何功能可以类似地完成此任务?