我是Pytorch的新手。即使阅读了文档,对于我来说,尚不清楚在有4维输入的情况下torch.argmax()如何应用于一维。另外, keepdims = True 如何更改输出?
以下是每种情况的示例:
k = torch.rand(2, 3, 4, 4)
print(k):
tensor([[[[0.2912, 0.4818, 0.1123, 0.3196],
[0.6606, 0.1547, 0.0368, 0.9475],
[0.4753, 0.7428, 0.5931, 0.3615],
[0.6729, 0.7069, 0.1569, 0.3086]],
[[0.6603, 0.7777, 0.3546, 0.2850],
[0.3681, 0.5295, 0.8812, 0.6093],
[0.9165, 0.2842, 0.0260, 0.1768],
[0.9371, 0.9889, 0.6936, 0.7018]],
[[0.5880, 0.0349, 0.0419, 0.3913],
[0.5884, 0.9408, 0.1707, 0.1893],
[0.3260, 0.4410, 0.6369, 0.7331],
[0.9448, 0.7130, 0.3914, 0.2775]]],
[[[0.9433, 0.8610, 0.9936, 0.1314],
[0.8627, 0.3103, 0.3066, 0.3547],
[0.3396, 0.1892, 0.0385, 0.5542],
[0.4943, 0.0256, 0.7875, 0.5562]],
[[0.2338, 0.2498, 0.4749, 0.2520],
[0.4405, 0.1605, 0.6219, 0.8955],
[0.2326, 0.1816, 0.5032, 0.8732],
[0.2089, 0.6131, 0.1898, 0.0517]],
[[0.1472, 0.8059, 0.6958, 0.9047],
[0.6403, 0.2875, 0.5746, 0.5908],
[0.8668, 0.4602, 0.8224, 0.9307],
[0.2077, 0.5665, 0.8671, 0.4365]]]])
argmax = torch.argmax(k, axis=1)
print(argmax):
tensor([[[1, 1, 1, 2],
[0, 2, 1, 0],
[1, 0, 2, 2],
[2, 1, 1, 1]],
[[0, 0, 0, 2],
[0, 0, 1, 1],
[2, 2, 2, 2],
[0, 1, 2, 0]]])
argmax = torch.argmax(k, axis=1, keepdims=True)
print(argmax):
tensor([[[[1, 1, 1, 2],
[0, 2, 1, 0],
[1, 0, 2, 2],
[2, 1, 1, 1]]],
[[[0, 0, 0, 2],
[0, 0, 1, 1],
[2, 2, 2, 2],
[0, 1, 2, 0]]]])
答案 0 :(得分:2)
根据定义,如果k
是形状为(2, 3, 4, 4)
的张量,则torch.argmax
与axis=1
一起应为形状为(2, 4, 4)
的输出。要了解为什么会发生这种情况,您必须首先了解在较低维度中会发生什么。
如果我有一个2D(2,2)张量A,例如:
[[1,2],
[3,4]]
然后torch.argmax(A, axis=1)
给出形状(2)的输出,其值为(1,1)。 axis参数表示要沿其运行的轴。因此设置axis=1
意味着它将在决定最大值之前逐一查看每一列中的值。对于第0行,它查看列值1、2,并确定2(在索引1处)为最大值。对于第1行,它查看列值3、4,并确定4(在索引1处)是最大值。因此argmax结果为[1,1]。
移动到3D,让我们假设一个尺寸数组(I,J,K)。如果我们以Axis = 1调用argmax,则可以将其分解为以下内容:
I, J, K = 3, 4, 5
A = torch.rand(I, J, K)
out = torch.zeros((I, K), dtype=torch.int32)
for i in range(I):
for k in range(K):
out[i,k] = torch.argmax(A[i,:,k])
print(out)
print(torch.argmax(A, axis=1))
Out:
tensor([[3, 3, 2, 3, 2],
[1, 1, 0, 1, 0],
[0, 1, 0, 3, 3]], dtype=torch.int32)
tensor([[3, 3, 2, 3, 2],
[1, 1, 0, 1, 0],
[0, 1, 0, 3, 3]])
所以发生的是,在您的3D张量中,您将再次沿列/轴1计算argmax。因此,对于(i,k)的每对唯一值,沿轴1都有正J值,对?这些J值中的最大值的索引插入到输出的位置(i,k)。
如果您了解这一点,那么您就可以了解4D中发生的情况。对于任何尺寸为(I,J,K,L)的4D张量,如果您用arg = 1调用argmax,那么对于(i,k,l)的每种组合,您将沿轴1精确获得J值-并且这些J值的argmax将出现在输出[i,k,l]。
keepdims
参数仅保留矩阵的维数。例如,在4D矩阵上的轴1处的argmax给出形状为(I,K,L)的3D结果,但是使用keepdims时,形状为(I,1,K,L)的结果也将是4D。 / p>
答案 1 :(得分:0)
Argmax给出对应于给定维度上最大值的索引。因此尺寸数不是问题。因此,当您在给定维度上应用argmax时,默认情况下PyTorch会将该维度折叠起来,因为其值已由单个索引替换。现在,如果您不想删除该尺寸而将其保留为一个尺寸,则可以使用keepdims=True
。