我想知道TensorFlow中是否可以使用gather_nd
或类似方法来完成以下操作。
我有两个张量:
values
,形状为[128, 100]
,indices
,形状为[128, 3]
,其中indices
的每一行都包含沿values
的第二维的索引(对于同一行)。我想使用values
为indices
编制索引。例如,我想要执行此操作的东西(使用松散符号表示张量):
values = [[0, 0, 0, 1, 1, 0, 1],
[1, 1, 0, 0, 1, 0, 0]]
indices = [[2, 3, 6],
[0, 2, 3]]
batched_gather(values, indices) = [[0, 1, 1], [1, 0, 0]]
此操作将遍历values
和indices
的每一行,并使用values
行在indices
行上进行收集。
在TensorFlow中有一种简单的方法吗?
谢谢!
答案 0 :(得分:1)
不确定这是否等同于“简单”,但是您可以为此使用def batched_gather(values, indices):
row_indices = tf.range(0, tf.shape(values)[0])[:, tf.newaxis]
row_indices = tf.tile(row_indices, [1, tf.shape(indices)[-1]])
indices = tf.stack([row_indices, indices], axis=-1)
return tf.gather_nd(values, indices)
:
[0, 1]
说明:其构想是构建索引向量,例如indices
,意思是“第0行和第1列中的值”。
该功能的tf.shape
自变量中已经给出了列索引。
行索引是从0到例如128(在您的示例中),但根据每行的列索引数重复(平铺)(在您的示例中为3;如果此数字固定,则可以对其进行硬编码,而不使用array([[[0, 2],
[0, 3],
[0, 6]],
[[1, 0],
[1, 2],
[1, 3]]])
)。
然后,将行和列索引堆叠起来以生成索引向量。在您的示例中,结果索引为
gather_nd
和{{1}}会产生所需的结果。