在numpy中,我可以在以下内容中建立索引:
a = np.random.randn(2,2,3)
b = np.eye(2,2).astype(np.uint8)
c = np.eye(2,2).astype(np.uint8)
print(a)
print("diff")
print(a[b,c,:])
,其中a [b,c,:]是2 * 2的张量。
[[[-1.01338087 0.70149058 0.55268617]
[ 2.56941124 1.12720312 -0.07219555]]
[[-0.04084548 0.17018995 2.14229567]
[-0.68017558 -0.91788125 1.1719151 ]]]
diff
[[[-0.68017558 -0.91788125 1.1719151 ]
[-1.01338087 0.70149058 0.55268617]]
[[-1.01338087 0.70149058 0.55268617]
[-0.68017558 -0.91788125 1.1719151 ]]]
但是在Pytorch中,我无法像a[b,c,:]
那样进行索引。谁知道该怎么做。谢谢〜
答案 0 :(得分:0)
在PyTorch中建立索引几乎类似于numpy。
a = torch.randn(2, 2, 3)
b = torch.eye(2, 2, dtype=torch.long)
c = torch.eye(2, 2, dtype=torch.long)
print(a)
print(a[b, c, :])
tensor([[[ 1.2471, 1.6571, -2.0504],
[-1.7502, 0.5747, -0.3451]],
[[-0.4389, 0.4482, 0.7294],
[-1.3051, 0.6606, -0.6960]]])
tensor([[[-1.3051, 0.6606, -0.6960],
[ 1.2471, 1.6571, -2.0504]],
[[ 1.2471, 1.6571, -2.0504],
[-1.3051, 0.6606, -0.6960]]])