torch.argmax如何在4维上工作

时间:2020-08-15 14:42:04

标签: python pytorch

我是Pytorch的新手。即使阅读了文档,对于我来说,尚不清楚在有4维输入的情况下torch.argmax()如何应用于一维。另外, keepdims = True 如何更改输出?

以下是每种情况的示例:

k = torch.rand(2, 3, 4, 4)
print(k):

tensor([[[[0.2912, 0.4818, 0.1123, 0.3196],
          [0.6606, 0.1547, 0.0368, 0.9475],
          [0.4753, 0.7428, 0.5931, 0.3615],
          [0.6729, 0.7069, 0.1569, 0.3086]],

         [[0.6603, 0.7777, 0.3546, 0.2850],
          [0.3681, 0.5295, 0.8812, 0.6093],
          [0.9165, 0.2842, 0.0260, 0.1768],
          [0.9371, 0.9889, 0.6936, 0.7018]],

         [[0.5880, 0.0349, 0.0419, 0.3913],
          [0.5884, 0.9408, 0.1707, 0.1893],
          [0.3260, 0.4410, 0.6369, 0.7331],
          [0.9448, 0.7130, 0.3914, 0.2775]]],


        [[[0.9433, 0.8610, 0.9936, 0.1314],
          [0.8627, 0.3103, 0.3066, 0.3547],
          [0.3396, 0.1892, 0.0385, 0.5542],
          [0.4943, 0.0256, 0.7875, 0.5562]],

         [[0.2338, 0.2498, 0.4749, 0.2520],
          [0.4405, 0.1605, 0.6219, 0.8955],
          [0.2326, 0.1816, 0.5032, 0.8732],
          [0.2089, 0.6131, 0.1898, 0.0517]],

         [[0.1472, 0.8059, 0.6958, 0.9047],
          [0.6403, 0.2875, 0.5746, 0.5908],
          [0.8668, 0.4602, 0.8224, 0.9307],
          [0.2077, 0.5665, 0.8671, 0.4365]]]])

argmax = torch.argmax(k, axis=1)
print(argmax):
tensor([[[1, 1, 1, 2],
         [0, 2, 1, 0],
         [1, 0, 2, 2],
         [2, 1, 1, 1]],

        [[0, 0, 0, 2],
         [0, 0, 1, 1],
         [2, 2, 2, 2],
         [0, 1, 2, 0]]])


argmax = torch.argmax(k, axis=1, keepdims=True)
print(argmax):
tensor([[[[1, 1, 1, 2],
          [0, 2, 1, 0],
          [1, 0, 2, 2],
          [2, 1, 1, 1]]],


        [[[0, 0, 0, 2],
          [0, 0, 1, 1],
          [2, 2, 2, 2],
          [0, 1, 2, 0]]]])

2 个答案:

答案 0 :(得分:2)

根据定义,如果k是形状为(2, 3, 4, 4)的张量,则torch.argmaxaxis=1一起应为形状为(2, 4, 4)的输出。要了解为什么会发生这种情况,您必须首先了解在较低维度中会发生什么。

如果我有一个2D(2,2)张量A,例如:

[[1,2],
 [3,4]]

然后torch.argmax(A, axis=1)给出形状(2)的输出,其值为(1,1)。 axis参数表示要沿其运行的轴。因此设置axis=1意味着它将在决定最大值之前逐一查看每一列中的值。对于第0行,它查看列值1、2,并确定2(在索引1处)为最大值。对于第1行,它查看列值3、4,并确定4(在索引1处)是最大值。因此argmax结果为[1,1]。

移动到3D,让我们假设一个尺寸数组(I,J,K)。如果我们以Axis = 1调用argmax,则可以将其分解为以下内容:

I, J, K = 3, 4, 5
A = torch.rand(I, J, K)
out = torch.zeros((I, K), dtype=torch.int32)

for i in range(I):
    for k in range(K):
        out[i,k] = torch.argmax(A[i,:,k])
        
print(out)
print(torch.argmax(A, axis=1))

Out:
tensor([[3, 3, 2, 3, 2],
        [1, 1, 0, 1, 0],
        [0, 1, 0, 3, 3]], dtype=torch.int32)
tensor([[3, 3, 2, 3, 2],
        [1, 1, 0, 1, 0],
        [0, 1, 0, 3, 3]])

所以发生的是,在您的3D张量中,您将再次沿列/轴1计算argmax。因此,对于(i,k)的每对唯一值,沿轴1都有正J值,对?这些J值中的最大值的索引插入到输出的位置(i,k)。

如果您了解这一点,那么您就可以了解4D中发生的情况。对于任何尺寸为(I,J,K,L)的4D张量,如果您用arg = 1调用argmax,那么对于(i,k,l)的每种组合,您将沿轴1精确获得J值-并且这些J值的argmax将出现在输出[i,k,l]。

keepdims参数仅保留矩阵的维数。例如,在4D矩阵上的轴1处的argmax给出形状为(I,K,L)的3D结果,但是使用keepdims时,形状为(I,1,K,L)的结果也将是4D。 / p>

答案 1 :(得分:0)

Argmax给出对应于给定维度上最大值的索引。因此尺寸数不是问题。因此,当您在给定维度上应用argmax时,默认情况下PyTorch会将该维度折叠起来,因为其值已由单个索引替换。现在,如果您不想删除该尺寸而将其保留为一个尺寸,则可以使用keepdims=True