keepdims在Numpy(Python)中的作用是什么?

时间:2016-12-02 07:43:56

标签: python numpy

当我使用np.sum时,我遇到了一个名为keepdims的参数。查看the docs后,我仍然无法理解keepdims的含义。

  

keepdims:bool,可选

     

如果将其设置为True,则缩小的轴将作为尺寸为1的尺寸保留在结果中。使用此选项,结果将正确地针对原始arr。

进行广播

如果有人能用一个简单的例子来理解这一点,我将不胜感激。

4 个答案:

答案 0 :(得分:13)

考虑一个小的2d数组:

In [180]: A=np.arange(12).reshape(3,4)
In [181]: A
Out[181]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

跨行总和;结果是(3,)数组

In [182]: A.sum(axis=1)
Out[182]: array([ 6, 22, 38])

但要A加{(1)} sum需要重塑

In [183]: A-A.sum(axis=1)
...
ValueError: operands could not be broadcast together with shapes (3,4) (3,) 
In [184]: A-A.sum(axis=1)[:,None]   # turn sum into (3,1)
Out[184]: 
array([[ -6,  -5,  -4,  -3],
       [-18, -17, -16, -15],
       [-30, -29, -28, -27]])

如果我使用keepdims,“结果将正确播放”A

In [185]: A.sum(axis=1, keepdims=True)   # (3,1) array
Out[185]: 
array([[ 6],
       [22],
       [38]])
In [186]: A-A.sum(axis=1, keepdims=True)
Out[186]: 
array([[ -6,  -5,  -4,  -3],
       [-18, -17, -16, -15],
       [-30, -29, -28, -27]])

如果我相反,我不需要keepdims。广播此总和是自动的:A.sum(axis=0)[None,:]。但是使用keepdims没有任何害处。

In [190]: A.sum(axis=0)
Out[190]: array([12, 15, 18, 21])    # (4,)
In [191]: A-A.sum(axis=0)
Out[191]: 
array([[-12, -14, -16, -18],
       [ -8, -10, -12, -14],
       [ -4,  -6,  -8, -10]])

如果您愿意,可以使用np.mean使这些操作更有意义,将数组规范化为列或行。在任何情况下,它都可以简化原始数组和sum / mean之间的进一步数学运算。

答案 1 :(得分:0)

如果您对矩阵求和,则可以使用"keepdims=True"保留维度 例如:

import numpy as np
x  = np.array([[1,2,3],[4,5,6]])
x.shape
# (2, 3)

np.sum(x, keepdims=True).shape
# (1, 1)
np.sum(x, keepdims=True)
# array([[21]]) <---the reault is still a 1x1 array

np.sum(x, keepdims=False).shape
# ()
np.sum(x, keepdims=False)
# 21 <--- the result is an integer with no dimesion

答案 2 :(得分:0)

keepdims = True,用于匹配矩阵的尺寸。如果我们将此设置保留为False,则将显示尺寸错误。 您可以在计算softmax熵时看到它

答案 3 :(得分:0)

keepdims = true;在这种情况下,将保存数组(矩阵)的尺寸。这意味着您获得的结果将针对您要实现这些方法的Array正确“广播”。

当您忽略它时,它只是一个普通的数组,不再有任何维度。

import numpy as np

x = np.random.rand(4,3)

#Output for below statement: (3,)
print((np.sum(x, axis=0)).shape)

#Output for below statement: (1, 3)
print((np.sum(x, axis=0, keepdims=True)).shape)