我想从二维张量的每一行中提取存储在另一个一维张量中的列。
import torch
test_tensor = tensor([1,-2,3], [-2,7,4]).float()
select_tensor = tensor([1,2])
所以在这个特定的例子中,我想获取第一行位置 1 中的元素(所以 -2)和第二行位置 2 中的元素(所以 4)。 我试过了:
test_tensor[:, select_tensor]
但这会为每一行选择位置 1 和 2 的元素。我怀疑这可能是我遗漏的一些非常简单的东西。
答案 0 :(得分:1)
如果您正在寻找带有索引的解决方案,您还需要对 axis=0
进行索引,您可以使用 torch.arange
来实现:
>>> test_tensor = torch.tensor([[1,-2,3], [-2,7,4]])
>>> select_tensor = torch.tensor([1,2])
>>> test_tensor[torch.arange(len(select_tensor)), select_tensor]
tensor([-2, 4])
答案 1 :(得分:1)
您可以使用torch.gather
import torch
test_tensor = torch.tensor([[1,-2,3], [-2,7,4]]).float()
select_tensor = torch.tensor([1,2], dtype=torch.int64).view(-1,1) # number of dimension should match with the test tensor.
final_tensor = torch.gather(test_tensor, 1, select_tensor)
final_tensor
输出
tensor([[-2.],
[ 4.]])
或者,使用 torch.view
来展平输出张量:final_tensor.view(-1)
会给你 tensor([-2., 4.])