如何在PyTorch中选择二维索引?

时间:2017-11-19 06:05:47

标签: numpy pytorch

鉴于a = torch.randn(3, 2, 4, 5),我如何选择(2, :, 0, :), (1, :, 1, :), (2, :, 2, :), (0, :, 3, :)之类的子张量(结果张量为(2, 4, 5)(4, 2, 5)

虽然a[2, :, 0, :]给出了

 0.5580 -0.0337  1.0048 -0.5044  0.6784
-1.6117  1.0084  1.1886  0.1278  0.3739
[torch.FloatTensor of size 2x5]

然而,a[[2, 1, 2, 0], :, [0, 1, 2, 3], :]给出了

  

TypeError:对张量执行基本索引,遇到使用list类型的对象索引dim 0的错误。唯一支持的类型是整数,切片,numpy标量,或者如果使用torch.LongTensor或torch.ByteTensor进行索引,则只能传递一个Tensor。

虽然numpy成功返回(4, 2, 5)张量。

1 个答案:

答案 0 :(得分:0)

它对你有用吗?

import torch

a = torch.randn(3, 2, 4, 5)
print(a.size())

b = [a[2, :, 0, :], a[1, :, 1, :], a[2, :, 2, :], a[0, :, 3, :]]
b = torch.stack(b, 0)

print(b.size()) # torch.Size([4, 2, 5])