我需要从二维张量中收集元素。例如, 二维张量A:
[
[1,2,3,4],
[5,6,7,8]
]
和二维张量B:
[
[0,2],
[1,3]
]
B指定所需元素的索引。 A和B是行对应的,这意味着在A的第一行中访问索引0和2的元素,在A的第二行中访问索引1和3的元素。因此,结果是2-D张量:
[
[1,3],
[6,8]
]
由于tf.gather()在1-D中做类似的事情,我想tf.gather()和tf.map_fn()的组合可以达到上述目的。还有其他更有效的方法吗?
谢谢。