使用位掩码将数据从一个张量复制到另一个张量

时间:2018-12-17 11:57:16

标签: pytorch

import numpy as np
import torch
a = torch.zeros(5)
b = torch.tensor(tuple((0,1,0,1,0)),dtype=torch.uint8)
c= torch.tensor([7.,9.])
print(a[b].size())
a[b]=c
print(a)
  

torch.Size([2])
张量([0.,7.,0.,9.,0。])

我正在努力了解其工作原理。我最初以为上面的代码使用Fancy索引,但是我意识到从 c 张量中获取的值将与标记为1的索引相对应地复制。此外,如果我未指定 b < / strong>作为 uint8 ,则以上代码将无法正常工作。有人可以解释一下上面代码的机制吗?

1 个答案:

答案 0 :(得分:2)

使用数组建立索引与numpy和我所知道的大多数其他向量化数学包中的工作原理相同。有两种情况:

  1. b的类型为uint8(认为布尔值,pytorch不能将booluint8区分)时,a[b]是1 -d数组,其中包含aa[i])中对应值bb[i])的子集。这些值是原始a的别名,因此,如果您对其进行修改,它们的对应位置也会随之更改。

  2. 可用于索引的另一种类型是int64的数组,在这种情况下a[b]创建形状为(*b.shape, *a.shape[1:])的数组。它的结构就像b替换了b[i]a[i])的每个元素。换句话说,通过指定应从a的哪个索引中提取数据来创建新数组。同样,这些值又是原始a的别名,因此,如果您修改a[b],则每个a[b[i]]的{​​{1}}的值都会改变。 this问题中显示了一个示例用例。

integer array indexingboolean array indexing中为numpy解释了这两种模式,对于后者,您必须记住pytorch使用i代替uint8

此外,如果您的目标是将数据从一个张量复制到另一个张量,则必须记住,bool之类的操作是就地操作(a[ixs] = b[ixs]已就地修改),我不能与autograd一起玩。如果要进行位置遮蔽,请使用torch.wherethis答案中显示了一个示例用例。