我经历过official doc和this,但很难理解发生了什么。
我试图理解DQN的源代码,并且它使用了第197行的collect函数。
有人可以简单地解释一下collect函数的作用吗?该功能的目的是什么?
答案 0 :(得分:33)
torch.gather
函数(或torch.Tensor.gather
)是一种多索引选择方法。查看官方文档中的以下示例:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1, 1],
# [ 4, 3]])
让我们从遍历不同参数的语义开始:第一个参数input
是我们要从中选择元素的源张量。第二个dim
是我们要收集的尺寸(或以tensorflow / numpy表示的轴)。最后,index
是索引input
的索引。
至于操作的语义,这是官方文档对其进行解释的方式:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
所以让我们看一下示例。
输入张量为[[1, 2], [3, 4]]
,并且dim参数为1
,即我们要从第二维进行收集。第二维的索引分别为[0, 0]
和[1, 0]
。
当我们“跳过”第一个维度(我们要收集的维度为1
)时,结果的第一个维度被隐式给出为index
的第一个维度。这意味着索引保留第二维或列索引,但不保留行索引。这些由index
张量本身的索引给出。
对于示例,这意味着输出将在其第一行中也选择input
张量的第一行的元素,如index
张量的第一行的第一行所给定。由于列索引由[0, 0]
给出,因此我们两次选择了输入第一行的第一个元素,结果为[1, 1]
。类似地,结果第二行的元素是通过input
张量的第二行的元素对index
张量的第二行进行索引的结果,从而得到[4, 3]
。
为进一步说明这一点,让我们在示例中交换尺寸:
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1, 2],
# [ 3, 2]])
如您所见,索引现在沿第一维收集。
对于您引用的示例,
current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))
gather
将按操作的批处理列表索引q值的行(即,一组q值中的每个样本q值)。结果将与您执行以下操作相同(尽管它比循环快得多):
q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)
答案 1 :(得分:18)
torch.gather
通过沿输入维dim
取每一行的值从输入张量创建新的张量。 torch.LongTensor
中的值作为index
传递,指定从每个“行”中获取哪个值。输出张量的尺寸与索引张量的尺寸相同。下图来自官方文档,对其进行了更清晰的说明:
(注意:在图中,索引从1开始而不是从0开始)。
在第一个示例中,给定的尺寸沿行(从上到下),因此对于result
的(1,1)位置,它取{{1 }}即index
。源值在(1,1)处为src
,因此,在1
中在(1,1)处输出1
。
同样,对于(2,2),来自1
的索引的行值为result
。在(3,2),src
中的值为3
,因此输出src
,依此类推。
类似地,对于第二个示例,沿列进行索引,因此在8
的(2,2)位置,来自8
的索引的列值为result
,因此从src
的(2,3)开始,3
被获取并输出到src
的(2,2)
答案 2 :(得分:9)
@Ritesh和@cleros给出了很好的答案(带有很多的赞成票),但是读完它们后我仍然有些困惑,我知道为什么。这篇文章也许会对像我这样的人有所帮助。
对于这类具有行和列的练习,我认为确实有助于使用非正方形对象,所以让我们从更大的4x3 source
(torch.Size([4, 3])
开始)使用source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])
。这会给我们
\\ This is the source tensor
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
现在,让我们开始沿列(dim=1
)建立索引并创建index = torch.tensor([[0,0],[1,1],[2,2],[0,1]])
,这是一个列表列表。这是键:由于我们的维是列,并且源具有4
行,因此index
必须包含4
列表!我们需要为每一行提供一个列表。运行source.gather(dim=1, index=index)
将给我们
tensor([[ 1, 1],
[ 5, 5],
[ 9, 9],
[10, 11]])
因此,index
中的每个列表为我们提供了从中提取值的列。 index
([0,0]
)的第一列表告诉我们看source
的第一行,并对该行的第一列(索引为零)进行两次,即[1,1]
。 index
([1,1]
)的第二个列表告诉我们看source
的第二行,并将该行的第二列两次,即{{1} }。跳至[5,5]
(index
)的第4个列表,要求我们查看[0,1]
的第4行和最后一行,并要求我们采用第1列({ {1}}),然后是第二列(source
),我们得到10
。
这是一个很漂亮的事情:11
的每个列表必须具有相同的长度,但是它们的长度可以取决于您的需要!例如,对于[10,11]
,index
将给我们
index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]])
输出将始终具有与source.gather(dim=1, index=index)
相同的行数,但是列数将等于tensor([[ 1, 2, 3, 2, 1],
[ 6, 5, 4, 5, 6],
[ 8, 9, 7, 9, 8],
[11, 10, 12, 10, 11]])
中每个列表的长度。例如,source
(index
)的第二个列表将转到index
的第二行,分别拉出第三,第二,第一,第二和第三项,是[2,1,0,1,2]
。请注意,source
中每个元素的值必须小于[6,5,4,5,6]
的列数(在这种情况下为index
),否则会出现source
错误。
切换到3
,我们现在将使用行而不是列。使用相同的out of bounds
,我们现在需要一个dim=0
,其中每个列表的长度等于source
中的列数。为什么?因为当我们逐列移动时,列表中的每个元素都代表index
中的行。
因此,source
将有source
给我们
index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]])
查看source.gather(dim=0, index=index)
(tensor([[ 1, 2, 3],
[ 1, 5, 9],
[ 4, 8, 12],
[10, 8, 3]])
)中的第一个列表,我们可以看到我们正在index
的3列中移动,选择了第一个元素(索引为零) ),即[0,0,0]
。 source
([1,2,3]
)中的第二个列表告诉我们在列中分别移动第一项,第二项和第三项,即index
。依此类推。
使用[0,1,2]
时,我们的[1,5,9]
必须拥有与dim=1
中的行数相同的列表数量,但是每个列表可以根据您的喜好长短。 。使用index
时,source
中的每个列表的长度必须与dim=0
中的列数相同,但是现在我们可以拥有任意数量的列表。但是,index
中的每个值必须小于source
中的行数(在这种情况下为index
)。
例如,source
会让4
给我们
index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]])
使用source.gather(dim=0, index=index)
时,输出的行数始终与tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12],
[ 1, 5, 9],
[ 4, 8, 12],
[10, 8, 3]])
相同,尽管列数将等于dim=1
中列表的长度。 source
中的列表数必须等于index
中的行数。但是,index
中的每个值必须小于source
中的列数。
对于index
,输出始终具有与source
相同的列数,但是行数将等于dim=0
中的列表数。 source
中每个列表的长度必须等于index
中的列数。但是,index
中的每个值必须小于source
中的行数。
仅此而已。超越这一点将遵循相同的模式。