用外行术语在pytorch中收集功能有什么作用?

时间:2018-06-23 09:54:01

标签: pytorch

我经历过official docthis,但很难理解发生了什么。

我试图理解DQN的源代码,并且它使用了第197行的collect函数。

有人可以简单地解释一下collect函数的作用吗?该功能的目的是什么?

3 个答案:

答案 0 :(得分:33)

torch.gather函数(或torch.Tensor.gather)是一种多索引选择方法。查看官方文档中的以下示例:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#        [ 4,  3]])

让我们从遍历不同参数的语义开始:第一个参数input是我们要从中选择元素的源张量。第二个dim是我们要收集的尺寸(或以tensorflow / numpy表示的轴)。最后,index是索引input的索引。 至于操作的语义,这是官方文档对其进行解释的方式:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

所以让我们看一下示例。

输入张量为[[1, 2], [3, 4]],并且dim参数为1,即我们要从第二维进行收集。第二维的索引分别为[0, 0][1, 0]

当我们“跳过”第一个维度(我们要收集的维度为1)时,结果的第一个维度被隐式给出为index的第一个维度。这意味着索引保留第二维或列索引,但不保留行索引。这些由index张量本身的索引给出。 对于示例,这意味着输出将在其第一行中也选择input张量的第一行的元素,如index张量的第一行的第一行所给定。由于列索引由[0, 0]给出,因此我们两次选择了输入第一行的第一个元素,结果为[1, 1]。类似地,结果第二行的元素是通过input张量的第二行的元素对index张量的第二行进行索引的结果,从而得到[4, 3]

为进一步说明这一点,让我们在示例中交换尺寸:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#        [ 3,  2]])

如您所见,索引现在沿第一维收集。

对于您引用的示例,

current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))

gather将按操作的批处理列表索引q值的行(即,一组q值中的每个样本q值)。结果将与您执行以下操作相同(尽管它比循环快得多):

q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)

答案 1 :(得分:18)

torch.gather通过沿输入维dim取每一行的值从输入张量创建新的张量。 torch.LongTensor中的值作为index传递,指定从每个“行”中获取哪个值。输出张量的尺寸与索引张量的尺寸相同。下图来自官方文档,对其进行了更清晰的说明: Pictoral representation from the docs

(注意:在图中,索引从1开始而不是从0开始)。

在第一个示例中,给定的尺寸沿行(从上到下),因此对于result的(1,1)位置,它取{{1 }}即index。源值在(1,1)处为src,因此,在1中在(1,1)处输出1。 同样,对于(2,2),来自1的索引的行值为result。在(3,2),src中的值为3,因此输出src,依此类推。

类似地,对于第二个示例,沿列进行索引,因此在8的(2,2)位置,来自8的索引的列值为result,因此从src的(2,3)开始,3被获取并输出到src的(2,2)

答案 2 :(得分:9)

@Ritesh和@cleros给出了很好的答案(带有很多的赞成票),但是读完它们后我仍然有些困惑,我知道为什么。这篇文章也许会对像我这样的人有所帮助。

对于这类具有行和列的练习,我认为确实有助于使用非正方形对象,所以让我们从更大的4x3 sourcetorch.Size([4, 3])开始)使用source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])。这会给我们

\\ This is the source tensor
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

现在,让我们开始沿列(dim=1)建立索引并创建index = torch.tensor([[0,0],[1,1],[2,2],[0,1]]),这是一个列表列表。这是:由于我们的维是列,并且源具有4行,因此index必须包含4列表!我们需要为每一行提供一个列表。运行source.gather(dim=1, index=index)将给我们

tensor([[ 1,  1],
        [ 5,  5],
        [ 9,  9],
        [10, 11]])

因此,index中的每个列表为我们提供了从中提取值的列。 index[0,0])的第一列表告诉我们看source的第一行,并对该行的第一列(索引为零)进行两次,即[1,1]index[1,1])的第二个列表告诉我们看source的第二行,并将该行的第二列两次,即{{1} }。跳至[5,5]index)的第4个列表,要求我们查看[0,1]的第4行和最后一行,并要求我们采用第1列({ {1}}),然后是第二列(source),我们得到10

这是一个很漂亮的事情:11的每个列表必须具有相同的长度,但是它们的长度可以取决于您的需要!例如,对于[10,11]index将给我们

index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]])

输出将始终具有与source.gather(dim=1, index=index)相同的行数,但是列数将等于tensor([[ 1, 2, 3, 2, 1], [ 6, 5, 4, 5, 6], [ 8, 9, 7, 9, 8], [11, 10, 12, 10, 11]]) 中每个列表的长度。例如,sourceindex)的第二个列表将转到index的第二行,分别拉出第三,第二,第一,第二和第三项,是[2,1,0,1,2]。请注意,source中每个元素的值必须小于[6,5,4,5,6]的列数(在这种情况下为index),否则会出现source错误。

切换到3,我们现在将使用行而不是列。使用相同的out of bounds,我们现在需要一个dim=0,其中每个列表的长度等于source中的列数。为什么?因为当我们逐列移动时,列表中的每个元素都代表index中的行。

因此,source将有source给我们

index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]])

查看source.gather(dim=0, index=index)tensor([[ 1, 2, 3], [ 1, 5, 9], [ 4, 8, 12], [10, 8, 3]]) )中的第一个列表,我们可以看到我们正在index的3列中移动,选择了第一个元素(索引为零) ),即[0,0,0]source[1,2,3])中的第二个列表告诉我们在列中分别移动第一项,第二项和第三项,即index。依此类推。

使用[0,1,2]时,我们的[1,5,9]必须拥有与dim=1中的行数相同的列表数量,但是每个列表可以根据您的喜好长短。 。使用index时,source中的每个列表的长度必须与dim=0中的列数相同,但是现在我们可以拥有任意数量的列表。但是,index中的每个值必须小于source中的行数(在这种情况下为index)。

例如,source会让4给我们

index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]])

使用source.gather(dim=0, index=index)时,输出的行数始终与tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12], [ 1, 5, 9], [ 4, 8, 12], [10, 8, 3]]) 相同,尽管列数将等于dim=1中列表的长度。 source中的列表数必须等于index中的行数。但是,index中的每个值必须小于source中的列数。

对于index,输出始终具有与source相同的列数,但是行数将等于dim=0中的列表数。 source中每个列表的长度必须等于index中的列数。但是,index中的每个值必须小于source中的行数。

仅此而已。超越这一点将遵循相同的模式。