Numpy sum keepdims错误

时间:2017-08-09 21:47:35

标签: python numpy

Python在矩阵上调用numpy sum函数时抛出错误。

probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

错误

probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
TypeError: sum() got an unexpected keyword argument 'keepdims'

上下文:计算softmax分类器的损失函数。 Numerator是正确类的得分函数的指数,而分母是所有可能类的所有指数的总和。

2 个答案:

答案 0 :(得分:2)

该论点在最新版本的numpy中有效,如here所述。以下是numpy.sum的完整参数列表:

  

numpy.sum(a,axis = None,dtype = None,out = None,keepdims = False)

这是自版本1.7以来添加的,您可以在源代码here中看到。因此,您需要升级您的numpy安装。

答案 1 :(得分:2)

在NumPy 1.7中添加了keepdims参数。至少np.sum (1.6)的文档字符串没有将其列为参数之一:

numpy.sum(a, axis=None, dtype=None, out=None)

但是1.7 docstring已经列出了它:

numpy.sum(a, axis=None, dtype=None, out=None, keepdims=False)

鉴于NumPy 1.6已在2012中发布,您可能应该更新您的NumPy包。

但是,如果您不能(或不想)更新NumPy,您也可以使用np.expand_dims

np.expand_dims(np.sum(exp_scores, axis=1), axis=1)