处理标量或数组的Python函数

时间:2015-03-28 14:44:20

标签: python arrays function numpy

如何最好地编写一个可以接受标量浮点数或numpy向量(1-d数组)的函数,并返回标量,1-d数组或2-d数组,具体取决于输入?

该函数很昂贵且经常被调用,我不想给调用者施加负担来对参数或返回值进行特殊强制转换。它只需要处理数字(不是列表或其他可迭代的东西)。

np.vectorize可能很慢(Broadcasting a python function on to numpy arrays),其他答案(Getting a Python function to cleanly return a scalar or list, depending on number of arguments)和np.asarray(A python function that accepts as an argument either a scalar or a numpy array)无法获得输出数组所需的维度。< / p>

这种类型的代码可以在Matlab,Javascript和其他语言中使用:

import numpy as np

def func( xa, ya ):
    # naively, I thought I could do:
    xy = np.zeros( ( len(xa), len(ya) ) )
    for j in range(len( ya )):
        for i in range(len( xa )):
            # do something complicated
            xy[i,j] = x[i]+y[j]            
    return xy

适用于数组:

x = np.array([1., 2.])
y = np.array([2., 4.])
xy = func(x,y)
print xy

[[ 3.  5.]
 [ 4.  6.]]

但是对于标量浮点数不起作用:

x = 1.
y = 3.
xy = func(x,y)
print xy

<ipython-input-64-0f85ad330609> in func(xa, ya)
      4 def func( xa, ya ):
      5     # naively, I thought I could do:
----> 6     xy = np.zeros( ( len(xa), len(ya) ) )
      7     for j in range(len( ya )):
      8         for i in range(len( xa )):

TypeError: object of type 'float' has no len()

在类似的函数中使用np.asarray给出:

<ipython-input-63-9ae8e50196e1> in func(x, y)
      5     xa = np.asarray( x );
      6     ya = np.asarray( y );
----> 7     xy = np.zeros( ( len(xa), len(ya) ) )
      8     for j in range(len( ya )):
      9         for i in range(len( xa )):

TypeError: len() of unsized object

什么是快速,优雅和pythonic的方法?

8 个答案:

答案 0 :(得分:12)

在整个numpy代码库中你会找到类似的东西:

def func_for_scalars_or_vectors(x):
    x = np.asarray(x)
    scalar_input = False
    if x.ndim == 0:
        x = x[None]  # Makes x 1D
        scalar_input = True

    # The magic happens here

    if scalar_input:
        return np.squeeze(ret)
    return ret

答案 1 :(得分:2)

&#34;可以接受标量浮点数或numpy向量(1-d数组)的函数,并返回标量,1-d数组或2-d数组&#34;

所以

  

标量=&gt;标量

     

1d =&gt; 2D

什么产生一维阵列?

def func( xa, ya ):
    def something_complicated(x, y):
        return x + y
    try:
        xy = np.zeros( ( len(xa), len(ya) ) )
        for j in range(len( ya )):
            for i in range(len( xa )):
                xy[i,j] = something_complicated(xa[i], ya[i])
    except TypeError:
        xy = something_complicated(xa, ya)  
    return xy

这是&#39;快速,优雅,和pythonic&#39;?

肯定是&#39; pythonic&#39;。 &#39;尝试/除&#39;是非常Pythonic。所以在另一个函数中定义一个函数。

快?只有时间测试会告诉你。它可能取决于标量对阵列示例的相对频率。

优雅?这是在旁观者的眼中。

这更优雅吗?这是有限的递归

def func( xa, ya ):
    try:
        shape = len(xa), len(ya)
    except TypeError:
        # do something complicated
        return xa+ya    
    xy = np.zeros(shape)
    for j in range(len( ya )):
        for i in range(len( xa )):
            xy[i,j] = func(xa[i], ya[i])           
    return xy

如果您需要正确处理2d +输入,那么vectorize显然是最省力的解决方案:

def something_complicated(x,y):
    return x+y
vsomething=np.vectorize(something_complicated)

In [409]: vsomething([1,2],[4,4])
Out[409]: array([5, 6])
In [410]: vsomething(1,3)
Out[410]: array(4)   # not quite a scalar

如果array(4)不是您想要的scalar输出,那么您必须添加测试并使用[()]提取值。 vectorize也处理标量和数组的混合(标量+ 1d =&gt; 1d)。

MATLAB没有标量。 size(3)会返回1,1

在Javascript中,[1,2,3]具有.length属性,但3没有。

来自nodejs会话的

> x.length
undefined
> x=[1,2,3]
[ 1, 2, 3 ]
> x.length
3

关于MATAB代码,Octave有关length函数

的说法
  

- 内置功能:长度(A)        返回对象A的长度。

     

空对象的长度为0,标量的长度为1,数量为        向量的元素。对于矩阵对象,长度是数字        行或列,以较大者为准(这个奇怪的定义是        用于与MATLAB兼容)。

MATLAB没有真正的标量。一切都至少2d。 A&#39; vector&#39;只有一个&#39; 1&#39;尺寸。 length是MATLAB中迭代控制的不良选择。我一直使用size

为了增加MATLAB的便利性,但也有可能产生混淆,x(i)适用于行&#39;向量&#39;和列&#39;向量&#39;,[1,2,3][1;2;3]x(i,j)也适用于两者,但具有不同的索引范围。

len在迭代Python列表时工作正常,但在使用numpy数组时不是最佳选择。如果您想要总项数,x.size会更好。如果您想要第一维,x.shape[0]会更好。

为什么没有一个优雅的Pythonic解决方案来解决你的问题的一部分原因是你从一些惯用的MATLAB开始,并期望Python表现出同样的细微差别。

答案 2 :(得分:1)

作为一个观点,我更希望函数在输入类型上具有灵活性,但总是返回一致的类型;这将最终阻止呼叫者检查返回类型(所述目标)。

例如,允许标量和/或数组作为参数,但始终返回数组。

def func(x, y):
    # allow (x=1,y=2) OR (x=[1,2], y=[3,4]) OR (!) (x=1,y=[2,3])
    xn = np.asarray([x]) if np.isscalar(x) else np.asarray(x)
    yn = np.asarray([y]) if np.isscalar(y) else np.asarray(y)

    # calculations with numpy arrays xn and xy
    res = xn + yn  # ..etc...
    return res

(尽管如此,通过设置标记“scalar=True”,yada yada yada,可以很容易地修改示例以返回标量..但是你还必须处理一个arg的标量,另一个是一个数组等;对我来说似乎有很多YAGNI。)

答案 3 :(得分:0)

我会做以下事情:

def func( xa, ya ):
    xalen = xa if type(xa) is not list else len(xa)
    yalen = ya if type(ya) is not list else len(ya)
    xy = np.zeros( (xalen, yalen) )    
    for j in range(yalen): 
        for i in range(xalen):
            xy[i,j] = x[i]+y[j] 
    return xy

答案 4 :(得分:0)

也许这不是最pythonic(也不是最快的),但它是最笨拙的方式:

import numpy as np

def func(xa, ya):
    xa, ya = map(np.atleast_1d, (xa, ya))
    # Naively, I thought I could do:
    xy = np.zeros((len(xa), len(ya)))
    for j in range(len(ya)):
        for i in range(len(xa)):
            # Do something complicated.
            xy[i,j] = xa[i] + ya[j]
    return xy.squeeze()

如果你正在寻找速度检查numba。

答案 5 :(得分:0)

首先将您的功能编写为不关心维度:

def func(xa, ya):
    # use x.shape, not len(x)
    xy = np.zeros(xa.shape + ya.shape)

    # use ndindex, not range
    for jj in np.ndindex(ya.shape):
        for ii in np.ndindex(xa.shape):
            # do something complicated
            xy[ii + jj] = x[ii] + y[jj]            
    return xy

答案 6 :(得分:0)

这就像函数装饰器可以执行的操作,如果您想将此行为应用于您编写的许多函数,就像我一样。我写了一个。原来比我希望的更杂乱,但是在这里。

当然,我们可能都应该编写明确采用并仅返回数组或标量的代码。显式胜于隐式。谨慎使用。

#splits string according to delimeters 
'''
Let's make a function that can split a string
into list according the given delimeters. 
example data: cat;dog:greff,snake/
example delimeters: ,;- /|:
'''
def string_to_splitted_array(data,delimeters):
    #result list
    res = []
    # we will add chars into sub_str until
    # reach a delimeter
    sub_str = ''
    for c in data: #iterate over data char by char
        # if we reached a delimeter, we store the result 
        if c in delimeters: 
            # avoid empty strings
            if len(sub_str)>0:
                # looks like a valid string.
                res.append(sub_str)
                # reset sub_str to start over
                sub_str = ''
        else:
            # c is not a deilmeter. then it is 
            # part of the string.
            sub_str += c
    # there may not be delimeter at end of data. 
    # if sub_str is not empty, we should att it to list. 
    if len(sub_str)>0:
        res.append(sub_str)
    # result is in res 
    return res

# test the function. 
delimeters = ',;- /|:'
# read the csv data from console. 
csv_string = input('csv string:')
#lets check if working. 
splitted_array = string_to_splitted_array(csv_string,delimeters)
print(splitted_array)

答案 7 :(得分:0)

尝试一下:

def func( xa, ya ):
    if not np.isscalar(xa):
        xa = np.array(xa)[:, None]   
    xy = xa + np.array(ya)
    return xy

输出:

> func([1, 2], [2, 4])
array([[3, 5],
       [4, 6]])

> func(3, [2, 4])
array([5, 7])

> func([2, 4], 3)
array([[5],
       [7]])