如何编写与普通Python值和NumPy数组兼容的条件代码?

时间:2016-09-08 11:32:16

标签: python arrays numpy vectorization

为了在Python中编写“分段函数”,我通常使用if(以控制流或三元运算符的形式)。

def spam(x):
    return x+1 if x>=0 else 1/(1-x)

现在,有了NumPy,我们的口头禅是为了避免性能,避免使用单值来支持矢量化。所以我认为这样的事情会更受欢迎: 正如Leon所言,以下是错误

def eggs(x):
    y = np.zeros_like(x)
    positive = x>=0
    y[positive] = x+1
    y[np.logical_not(positive)] = 1/(1-x)
    return y

(纠正我,如果我在这里错过了一些东西,因为坦率地说我发现这很难看。)

现在,如果eggs实际上是一个NumPy数组,当然x只会 ,因为否则x>=0只会产生一个布尔值,这可以&# 39;用于索引(至少没有做正确的事情)。

是否有一种很好的方法来编写看起来更像spam但在Numpy数组上惯用的代码,或者我应该只使用vectorize(spam)

2 个答案:

答案 0 :(得分:4)

使用np.where。但是,即使对于普通数字输入,您也会得到一个数组作为输出。

def eggs(x):
    y = np.asarray(x)
    return np.where(y>=0, y+1, 1/(1-y))

这适用于数组和普通数字:

>>> eggs(5)
array(6.0)
>>> eggs(-3)
array(0.25)
>>> eggs(np.arange(-3, 3))
/home/praveen/.virtualenvs/numpy3-mkl/bin/ipython3:2: RuntimeWarning: divide by zero encountered in true_divide
array([ 0.25      ,  0.33333333,  0.5       ,  1.        ,  2.        ,  3.        ])
>>> eggs(1)
/home/praveen/.virtualenvs/numpy3-mkl/bin/ipython3:3: RuntimeWarning: divide by zero encountered in long_scalars
  # -*- coding: utf-8 -*-
array(2.0)

正如ayhan所言,这引发了一个警告,因为1/(1-x)被评估了整个范围。但警告只是:警告。如果您知道自己在做什么,可以忽略该警告。在这种情况下,您只能从永远不会1/(1-x)的索引中选择inf,这样您才能安全。

答案 1 :(得分:2)

如果我想处理数字和numpy数组,我会使用numpy.asarray(如果参数已经是一个numpy数组,这是一个无操作)

def eggs(x):
    x = np.asfarray(x)
    m = x>=0
    x[m] = x[m] + 1
    x[~m] = 1 / (1 - x[~m])
    return x

(这里我使用asfarray来强制执行浮点类型,因为你的函数需要浮点计算。)

对于单个输入,这比垃圾邮件功能效率低,可以说是丑陋的。但它似乎是最简单的选择。

编辑:如果你想确保x没有被修改(如Leon指出的那样)你可以用np.asfarray(x)替换np.array(x, dtype=np.float64),默认情况下数组构造函数会复制。