是否有一种将标量和0d数组更改为1d数组的pythonic方法?

时间:2014-08-23 04:50:05

标签: python numpy

例如,我想改变一个'如果' a'小于5

def foo(a):
    return 0 if a < 5 else a

使它适用于numpy数组,我改为:

def foo2(a):
    a[a < 5] = 0
    return a

问题是我希望该函数适用于标量和数组。

isscalar()函数可以测试是否&#39; a&#39;是一个标量,但它为0d数组返回false,如array(12)。

有没有pythonic方法将标量和0d数组更改为1d数组并保持其他ndarray不变?

6 个答案:

答案 0 :(得分:4)

好吧,我带来了似乎有效的解决方案

def foo3(a):
    return a * (a >= 5)

foo3(4)
=> 0

foo3(6)
=> 6

foo3(np.array(3))
=> 0

foo3(np.array(6))
=> 6

foo3(np.array([1, 5]))
=> array([0, 5])

它工作正常,但我不知道这样做是否安全。

答案 1 :(得分:2)

您可以使用numpy.vectorize和原始标量实现。

@np.vectorize
def foo(a):
   return 0 if a < 5 else a

foo(3)
=> array(0)
foo(7)
=> array(7)
foo([[3,7], [8,-1]])
=> array([[0, 7],
          [8, 0]])

请注意,在使用vectorize时,为了简单起见,您放弃了速度(您的计算不再在numpy / C级别进行矢量化)(您可以用简单的形式编写函数,用于标量)。

答案 2 :(得分:1)

如果你不介意函数返回一个数组,即使它有一个标量,我也倾向于这样做:

import numpy as np

def foo(a,k=5):
    b = np.array(a)
    if not b.ndim:
        b = np.array([a])
    b[b < k] = 0
    return b

print(foo(3))
print(foo(6))
print(foo([1,2,3,4,5,6,7,8,9]))
print(foo(np.array([1,2,3,4,5,6,7,8,9])))

...产生结果:

[0]
[6]
[0 0 0 0 5 6 7 8 9]
[0 0 0 0 5 6 7 8 9]

从测试示例中可以看出,如果它提供了常规的Python列表而不是numpy数组或标量,则此函数可以正常工作。

在此过程中创建两个数组可能看起来很浪费,但首先,创建b会阻止函数产生不必要的副作用。考虑一下:

def foobad(a,k=5):
    a[a < k] = 0
    return a

x = np.array([1,2,3,4,5,6,7,8,9])
foobad(x)
print (x)

...打印:

[0 0 0 0 5 6 7 8 9]

...这可能不是该功能的用户所期望的。其次,如果第二个数组创建是因为函数是用标量提供的,那么它只会从1个元素的列表中创建一个数组,这应该非常快。

答案 3 :(得分:0)

这是对问题最后部分的回答。在使用np.reshape检查维度后,使用np.ndim将标量或0d数组更改为1d数组的快速方法。

import numpy as np

a = 1
if np.ndim(a) == 0:
    np.reshape(a, (-1))
=> array([1])

然后,

b = np.array(1)
=> array(1) # 0d array
if np.ndim(b) == 0:
    np.reshape(b, (-1))
=> array([1]) # 1d array. you can iterate over this.

答案 4 :(得分:0)

试试这个

def foo(a, b=5):
    ix = a < b
    if not np.isscalar(ix):
        a[ix] = 0
    elif ix:
        a = 0
    return a

print([foo(1), foo(10), foo(np.array(1)), foo(np.array(10)), foo(np.arange(10))])

输出

[0, 10, 0, array(10), array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9])]

请注意array(1) > 0代替bool代替np.bool_,因此np.isscalar可以安全使用ix

答案 5 :(得分:0)

使用np.atleast_1d

这适用于任何输入(标量或数组):

def foo(a):
    a = np.atleast_1d(a)
    a[a < 5] = 0
    return a

但请注意,这将为标量输入返回1d数组。