鉴于a = torch.randn(3, 2, 4, 5)
,我如何选择(2, :, 0, :), (1, :, 1, :), (2, :, 2, :), (0, :, 3, :)
之类的子张量(结果张量为(2, 4, 5)
或(4, 2, 5)
?
虽然a[2, :, 0, :]
给出了
0.5580 -0.0337 1.0048 -0.5044 0.6784
-1.6117 1.0084 1.1886 0.1278 0.3739
[torch.FloatTensor of size 2x5]
然而,a[[2, 1, 2, 0], :, [0, 1, 2, 3], :]
给出了
TypeError:对张量执行基本索引,遇到使用list类型的对象索引dim 0的错误。唯一支持的类型是整数,切片,numpy标量,或者如果使用torch.LongTensor或torch.ByteTensor进行索引,则只能传递一个Tensor。
虽然numpy
成功返回(4, 2, 5)
张量。
答案 0 :(得分:0)
它对你有用吗?
import torch
a = torch.randn(3, 2, 4, 5)
print(a.size())
b = [a[2, :, 0, :], a[1, :, 1, :], a[2, :, 2, :], a[0, :, 3, :]]
b = torch.stack(b, 0)
print(b.size()) # torch.Size([4, 2, 5])