用一维张量对二维张量进行子集

时间:2021-01-19 20:00:41

标签: python pytorch tensor

我想从二维张量的每一行中提取存储在另一个一维张量中的列。

import torch
test_tensor = tensor([1,-2,3], [-2,7,4]).float()
select_tensor = tensor([1,2])

所以在这个特定的例子中,我想获取第一行位置 1 中的元素(所以 -2)和第二行位置 2 中的元素(所以 4)。 我试过了:

test_tensor[:, select_tensor]

但这会为每一行选择位置 1 和 2 的元素。我怀疑这可能是我遗漏的一些非常简单的东西。

2 个答案:

答案 0 :(得分:1)

如果您正在寻找带有索引的解决方案,您还需要对 axis=0 进行索引,您可以使用 torch.arange 来实现:

>>> test_tensor = torch.tensor([[1,-2,3], [-2,7,4]])
>>> select_tensor = torch.tensor([1,2])

>>> test_tensor[torch.arange(len(select_tensor)), select_tensor]
tensor([-2,  4])

答案 1 :(得分:1)

您可以使用torch.gather

import torch
test_tensor = torch.tensor([[1,-2,3], [-2,7,4]]).float()
select_tensor = torch.tensor([1,2], dtype=torch.int64).view(-1,1) # number of dimension should match with the test tensor.
final_tensor = torch.gather(test_tensor, 1, select_tensor)
final_tensor

输出

tensor([[-2.],
        [ 4.]])

或者,使用 torch.view 来展平输出张量:final_tensor.view(-1) 会给你 tensor([-2., 4.])