import numpy as np
import torch
a = torch.zeros(5)
b = torch.tensor(tuple((0,1,0,1,0)),dtype=torch.uint8)
c= torch.tensor([7.,9.])
print(a[b].size())
a[b]=c
print(a)
torch.Size([2])
张量([0.,7.,0.,9.,0。])
我正在努力了解其工作原理。我最初以为上面的代码使用Fancy索引,但是我意识到从 c 张量中获取的值将与标记为1的索引相对应地复制。此外,如果我未指定 b < / strong>作为 uint8 ,则以上代码将无法正常工作。有人可以解释一下上面代码的机制吗?
答案 0 :(得分:2)
使用数组建立索引与numpy和我所知道的大多数其他向量化数学包中的工作原理相同。有两种情况:
当b
的类型为uint8
(认为布尔值,pytorch不能将bool
与uint8
区分)时,a[b]
是1 -d数组,其中包含a
(a[i]
)中对应值b
(b[i]
)的子集。这些值是原始a
的别名,因此,如果您对其进行修改,它们的对应位置也会随之更改。
可用于索引的另一种类型是int64
的数组,在这种情况下a[b]
创建形状为(*b.shape, *a.shape[1:])
的数组。它的结构就像b
替换了b[i]
(a[i]
)的每个元素。换句话说,通过指定应从a
的哪个索引中提取数据来创建新数组。同样,这些值又是原始a
的别名,因此,如果您修改a[b]
,则每个a[b[i]]
的{{1}}的值都会改变。 this问题中显示了一个示例用例。
在integer array indexing和boolean array indexing中为numpy解释了这两种模式,对于后者,您必须记住pytorch使用i
代替uint8
。
此外,如果您的目标是将数据从一个张量复制到另一个张量,则必须记住,bool
之类的操作是就地操作(a[ixs] = b[ixs]
已就地修改),我不能与autograd一起玩。如果要进行位置遮蔽,请使用torch.where
。 this答案中显示了一个示例用例。