为了在Python中编写“分段函数”,我通常使用if
(以控制流或三元运算符的形式)。
def spam(x):
return x+1 if x>=0 else 1/(1-x)
现在,有了NumPy,我们的口头禅是为了避免性能,避免使用单值来支持矢量化。所以我认为这样的事情会更受欢迎: 正如Leon所言,以下是错误
def eggs(x):
y = np.zeros_like(x)
positive = x>=0
y[positive] = x+1
y[np.logical_not(positive)] = 1/(1-x)
return y
(纠正我,如果我在这里错过了一些东西,因为坦率地说我发现这很难看。)
现在,如果eggs
实际上是一个NumPy数组,当然x
只会 ,因为否则x>=0
只会产生一个布尔值,这可以&# 39;用于索引(至少没有做正确的事情)。
是否有一种很好的方法来编写看起来更像spam
但在Numpy数组上惯用的代码,或者我应该只使用vectorize(spam)
?
答案 0 :(得分:4)
使用np.where
。但是,即使对于普通数字输入,您也会得到一个数组作为输出。
def eggs(x):
y = np.asarray(x)
return np.where(y>=0, y+1, 1/(1-y))
这适用于数组和普通数字:
>>> eggs(5)
array(6.0)
>>> eggs(-3)
array(0.25)
>>> eggs(np.arange(-3, 3))
/home/praveen/.virtualenvs/numpy3-mkl/bin/ipython3:2: RuntimeWarning: divide by zero encountered in true_divide
array([ 0.25 , 0.33333333, 0.5 , 1. , 2. , 3. ])
>>> eggs(1)
/home/praveen/.virtualenvs/numpy3-mkl/bin/ipython3:3: RuntimeWarning: divide by zero encountered in long_scalars
# -*- coding: utf-8 -*-
array(2.0)
正如ayhan所言,这引发了一个警告,因为1/(1-x)
被评估了整个范围。但警告只是:警告。如果您知道自己在做什么,可以忽略该警告。在这种情况下,您只能从永远不会1/(1-x)
的索引中选择inf
,这样您才能安全。
答案 1 :(得分:2)
如果我想处理数字和numpy数组,我会使用numpy.asarray(如果参数已经是一个numpy数组,这是一个无操作)
def eggs(x):
x = np.asfarray(x)
m = x>=0
x[m] = x[m] + 1
x[~m] = 1 / (1 - x[~m])
return x
(这里我使用asfarray来强制执行浮点类型,因为你的函数需要浮点计算。)
对于单个输入,这比垃圾邮件功能效率低,可以说是丑陋的。但它似乎是最简单的选择。
编辑:如果你想确保x
没有被修改(如Leon指出的那样)你可以用np.asfarray(x)
替换np.array(x, dtype=np.float64)
,默认情况下数组构造函数会复制。