在这个numpy logsumexp计算中避免无穷大

时间:2014-09-04 15:57:25

标签: python numpy infinity

Value error: truth value ambiguous跟进,我正在从这里编辑logsumexp函数:https://github.com/scipy/scipy/blob/v0.14.0/scipy/misc/common.py#L18

原因是因为: 1.我想自己选择最大值,它并不总是数组的最大值 2.我想提出一个条件,以确保从每个元素中减去最大值后的差异不会低于某个阈值。

这是我的最终代码。它没有任何问题 - 除了它有时仍然会返回它们的意思!

def mylogsumexp(self, a, is_class, maxaj=None, axis=None, b=None):
        threshold = -sys.float_info.max         
        a = asarray(a)
        if axis is None:
            a = a.ravel()
        else:
            a = rollaxis(a, axis)

        if is_class == 1:
            a_max = a.max(axis=0)
        else:
            a_max = maxaj  
        if b is not None:
            b = asarray(b)
            if axis is None:
                b = b.ravel()
            else:
                b = rollaxis(b, axis)
            #out = log(sum(b * exp(threshold if a - a_max < threshold else a - a_max), axis=0))
            out = np.log(np.sum(b * np.exp( np.minimum(a - a_max, threshold)), axis=0))

        else:
            out = np.log(np.sum(np.exp( np.minimum(a - a_max, threshold)), axis=0))
        out += a_max

1 个答案:

答案 0 :(得分:1)

您可以使用np.clip绑定数组的最大值和最小值:

>>> arr = np.arange(10)
>>> np.clip(arr, 3, 7)
array([3, 3, 3, 3, 4, 5, 6, 7, 7, 7])

在此示例中,大于7的值的上限为7;小于3的值设置为3。

如果我已正确解释您的代码,您可能需要替换

out = np.log(np.sum(b * np.exp( np.minimum(a - a_max, threshold)), axis=0))

out = np.log(np.sum(b * np.exp( np.clip(a - a_max, threshold, maximum)), axis=0))

其中maximum是您想要的最大值。