了解numpy.where和等效替代方案的运行时

时间:2016-07-11 18:04:40

标签: python arrays numpy

根据http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html,如果给出x和y并且输入数组是1-D,则其等于[xv if c else yv for (c,xv, yv) in zip(x!=0, 1/x, x)]。但是,在进行运行时基准测试时,它们的速度差异很大:

x = np.array(range(-500, 500))

%timeit np.where(x != 0, 1/x, x)
10000 loops, best of 3: 23.9 µs per loop

%timeit [xv if c else yv for (c,xv, yv) in zip(x!=0, 1/x, x)]
1000 loops, best of 3: 232 µs per loop

有没有办法可以重写第二个表单,以便它与第一个表单具有相似的运行时间?我问的原因是因为我想使用第二种情况的略微修改版本来避免除以零错误:

[1 / xv if c else xv for (c,xv) in zip(x!=0, x)]

另一个问题:第一种情况返回一个numpy数组,而第二种情况返回一个列表。让第二种情况返回数组的最有效方法是首先创建一个列表然后将列表转换为数组吗?

np.array([xv if c else yv for (c,xv, yv) in zip(x!=0, 1/x, x)])

谢谢!

4 个答案:

答案 0 :(得分:3)

你刚刚问过“推迟”' ' where':

numpy.where : how to delay evaluating parameters?

而其他人刚刚问过除零:

Replace all elements of a matrix by their inverses

当人们说where与列表理解类似时,他们会尝试描述行动,而不是实际的实施。

仅使用一个参数调用的

np.wherenp.nonzero相同。这很快(在编译的代码中)循环遍历参数,并收集所有非零值的索引。

当使用3个参数调用

np.where时,返回一个新数组,根据nonzero值从第2和第3个参数中收集值。但要意识到这些参数必须是其他数组,这一点很重要。它们不是逐个元素评估的函数。

所以where更像是:

m1 = 1/xv
m2 = xv
[v1 if c else v2 for (c, v1, v2) in zip(x!=0, m1, m2)]

在编译代码中运行此迭代很容易,因为它只涉及3个匹配大小的数组(通过广播匹配)。

np.array([...])是将列表(或列表推导)转换为数组的合理方法。它可能比某些替代方案慢一点,因为np.array是一个强大的通用功能。 np.fromiter([], dtype)在某些情况下可能会更快,因为它不是一般的(您必须指定dtype,它只适用于1d)。

有两种经过时间验证的策略可以在逐个元素的计算中获得更快的速度:

  • 使用numbacython等软件包将问题重写为c代码

  • 重新计算您的计算以使用现有的numpy方法。使用掩蔽来避免被零除以是一个很好的例子。

=====================

np.ma.where,掩码数组的版本是用Python编写的。它的代码可能很有启发性。特别注意这篇文章:

# Construct an empty array and fill it
d = np.empty(fc.shape, dtype=ndtype).view(MaskedArray)
np.copyto(d._data, xv.astype(ndtype), where=fc)
np.copyto(d._data, yv.astype(ndtype), where=notfc)

它生成目标,然后根据条件数组选择性地复制2个输入数组中的值。

答案 1 :(得分:1)

通过使用高级索引来保持性能,可以避免被零除:

x = np.arange(-500, 500)

result = np.empty(x.shape, dtype=float) # set the dtype to whatever is appropriate
nonzero = x != 0
result[nonzero] = 1/x[nonzero]
result[~nonzero] = 0

答案 2 :(得分:1)

如果由于某种原因想要绕过numpy的错误,可能值得查看errstate上下文:

x  = np.array(range(-500, 500))

with np.errstate(divide='ignore'): #ignore zero-division error
    x = 1/x
x[x!=x] = 0 #convert inf and NaN's to 0

答案 3 :(得分:0)

考虑使用np.put()

更改阵列到位
In [56]: x = np.linspace(-1, 1, 5)

In [57]: x
Out[57]: array([-1. , -0.5,  0. ,  0.5,  1. ])

In [58]: indices = np.argwhere(x != 0)

In [59]: indices
Out[59]:
array([[0],
       [1],
       [3],
       [4]], dtype=int64)

In [60]: np.put(x, indices, 1/x[indices])

In [61]: x
Out[61]: array([-1., -2.,  0.,  2.,  1.])

上述方法不会创建新数组,如果x是一个大数组,这可能会非常方便。