使用方括号将火炬张量子集化

时间:2020-02-13 06:44:16

标签: python pytorch

我在PyTorch中遇到了用于将3D张量缩减为2D张量的代码行。 3D张量x的大小为torch.Size([500, 50, 1]),此行代码为:

x = x[lengths - 1, range(len(lengths))]

用于将x缩小为大小torch.Size([50, 1])的2D张量。 lengths也是包含值的形状torch.Size([50])的张量。

请任何人解释这是如何工作的?谢谢。

2 个答案:

答案 0 :(得分:2)

在被这种行为所困扰之后,我做了一些进一步的研究,发现它是consistent behavior with the indexing of multi-dimensional NumPy arrays。造成这种不直观的原因是,两个数组具有的长度相同,即在这种情况下为len(lengths)

实际上,它的工作方式如下: * lengths正在确定您访问第一个维度的顺序。也就是说,如果您拥有一维数组a = [0, 1, 2, ...., 500],并使用列表b = [300, 200, 100]进行访问,则返回结果a[b] = [301, 201, 101](这也解释了lengths - 1运算符,它只会导致访问的值分别与blengths中使用的索引相同)。 * range(len(lengths))然后*简单地选择第i行中的第i个元素。如果您有正方形矩阵,则可以将其解释为矩阵的对角线。由于您仅沿前两个维度访问每个位置的单个元素,因此可以将其存储在单个维度中(因此将3D张量减小为2D)。后者维只是按“原样”保存。

如果您想解决这个问题,强烈建议将range()值更改为更长或更短的值,这将导致以下错误:

IndexError:形状不匹配:无法广播索引数组 连同形状(x,)(y,)

其中xy是您的特定长度值。

要以长形式写出此访问方法以了解“幕后”的情况,请考虑以下示例:

import torch
x = torch.randint(500, 50, 1)
lengths = torch.tensor([2, 30, 1, 4])  # random examples to explore
diag = list(range(len(lengths)))  # [0, 1, 2, 3]
result = []
for i, row in enumerate(lengths):
    temp_tensor = x[row, :, :]  # temp_tensor.shape = [1, 50, 1]
    temp_tensor = temp_tensor.squeeze(0)[diag[i]]  # temp_tensor.shape = [1, 1]
    result.append(temp.tensor)

# back to pytorch
result = torch.tensor(result)
result.shape  # [4, 1]

答案 1 :(得分:2)

此处的关键功能是传递张量lengths的值作为x的索引。 在此简化的示例中,我交换了容器的尺寸,因此索引维度排在第一位:

container = torch.arange(0, 50 )
container = f.reshape((5, 10))
>>>tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

indices = torch.arange( 2, 7, dtype=torch.long )
>>>tensor([2, 3, 4, 5, 6])

print( container[ range( len(indices) ), indices] )
>>>tensor([ 2, 13, 24, 35, 46])    

注意:我们从一行中得到一件事(range( len(indices) )产生连续的行号),列号由索引[row_number]

给出