使用索引张量选择张量的第二个暗度

时间:2019-02-21 05:08:37

标签: pytorch

我有一个2D张量和一个索引张量。 2D张量具有批处理尺寸和具有3个值的尺寸。我有一个索引张量,可以从3个值中精确选择1个元素。产生仅包含索引张量中的元素的切片的“最佳”方法是什么?

t = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
t = tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

i = torch.tensor([0,0,1], dtype=torch.int64)
tensor([0, 0, 1])

预期输出...

tensor([1, 4, 8])

1 个答案:

答案 0 :(得分:2)

答案示例如下。

import torch

t = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
col_i = [0, 0, 1]
row_i = range(3)
print(t[row_i, col_i])
# tensor([1, 4, 8])