我有一个2D张量和一个索引张量。 2D张量具有批处理尺寸和具有3个值的尺寸。我有一个索引张量,可以从3个值中精确选择1个元素。产生仅包含索引张量中的元素的切片的“最佳”方法是什么?
t = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
t = tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
i = torch.tensor([0,0,1], dtype=torch.int64)
tensor([0, 0, 1])
预期输出...
tensor([1, 4, 8])
答案 0 :(得分:2)
答案示例如下。
import torch
t = torch.tensor([[1,2,3], [4,5,6], [7,8,9]])
col_i = [0, 0, 1]
row_i = range(3)
print(t[row_i, col_i])
# tensor([1, 4, 8])