实现numpy 1.8和1.9 nansum之间的兼容性?

时间:2014-10-04 14:19:22

标签: numpy

我的代码需要具有与numpy版本完全相同的行为,但基础np.nansum函数已更改行为,使np.nansum([np.nan,np.nan])在{1.9} 0.0NaN中1.8。 < = 1.8行为是我更喜欢的行为,但更重要的是我的代码对于numpy版本是健壮的。

棘手的是,代码将任意numpy函数(通常是np.nan[something]函数)应用于ndarray。有没有办法强迫新的或旧的numpy nan[something]函数符合旧的或新的行为害羞monkeypatching它们?

我能想到的一个可能的解决方案就像outarr[np.allnan(inarr, axis=axis)] = np.nan,但没有np.allnan函数 - 如果这是最好的解决方案,那么这是最好的实现np.all(np.isnan(arr), axis=axis)(这将需要只支持np> = 1.7,但这可能没问题?

1 个答案:

答案 0 :(得分:1)

在Numpy 1.8中,nansum被定义为:

a, mask = _replace_nan(a, 0)

if mask is None:
    return np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims)
mask = np.all(mask, axis=axis, keepdims=keepdims)
tot = np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims)
if np.any(mask):
    tot = _copyto(tot, np.nan, mask)
    warnings.warn("In Numpy 1.9 the sum along empty slices will be zero.",
                  FutureWarning)
return tot

在Numpy 1.9中,它是:

a, mask = _replace_nan(a, 0)
return np.sum(a, axis=axis, dtype=dtype, out=out, keepdims=keepdims)

我认为有一种方法可以让新的nansum表现得很旧,但考虑到原来的nansum代码不是很长,你能否只提供一份如果您关心保留1.8之前的行为,那么该代码(没有警告)?

请注意,_copyto可以导入numpy.lib.nanfunctions