在numpy数组中应用自定义函数

时间:2019-05-13 04:58:53

标签: pandas numpy filter

我有一个清单,

mylist=np.array([120,3,10,33,5,54,2,23,599,801])

和一个功能

def getSum(n): 
    n=n**2
    sum = 0
    while (n != 0): 

        sum = sum + int(n % 10) 
        n = int(n/10) 
    if sum <20:
        return True
    return False

我正在尝试将函数应用于mylist并仅检索那些正确的索引。

我的预期输出是

[120, 3, 10, 33, 5, 54, 2, 23, 801]

我可以像list(filter(getSum,mylist))那样在numpy中使用它。

尝试np.where未产生预期的输出。

5 个答案:

答案 0 :(得分:2)

如果要检查数字的总和是否为> 20,可以在这里找到纯numpy的解决方案(here可以找到如何分解数字中的整数):

import numpy as np


mylist=np.array([120,3,10,33,5,54,2,23,599,801])

mylist = mylist**2
max_digits = np.ceil(np.max(np.log10(mylist)))  # max number of digits in mylist
digits = mylist//(10**np.arange(max_digits)[:, None])%10  # matrix of digits
digitsum = np.sum(digits, axis=0)  # array of sums
mask = digitsum < 20
mask
# array([True, True, True, True, True, True, True, True, False, True])

更新:速度比较

@hpaulj在(几乎)所有建议的解决方案之间进行了很好的时间比较。
获胜者是filter,输入纯list,而我的 pure numpy 解决方案效果不佳。
无论如何,如果我们在更广泛的输入范围内对它们进行测试,情况就会改变。
这是使用@NicoSchlömer的perflot执行的测试。
对于100多个元素的输入,所有解决方案都是等效的,而纯粹的numpy更快: enter image description here

答案 1 :(得分:1)

我认为有循环,所以最好使用numba

from numba import jit
@jit(nopython=True)
def get_vals(arr):
    out = np.zeros(arr.shape[0], dtype=bool)
    for i, n in enumerate(arr):

        n=n**2
        sum1 = 0
        while (n != 0): 
            sum1 = sum1 + int(n % 10) 
            n = int(n/10) 
        if sum1 <20:
            out[i] = True
    return arr[out]

print(get_vals(mylist))

答案 2 :(得分:1)

使用list comprehensionnp.vectorize的基本概念是从文档进行循环(也不会提高性能):

mylist[[getSum(i) for i in mylist]]

array([120,   3,  10,  33,   5,  54,   2,  23, 801])

答案 3 :(得分:1)

函数和测试数组:

In [22]: def getSum(n):  
    ...:     n=n**2 
    ...:     sum = 0 
    ...:     while (n != 0):  
    ...:  
    ...:         sum = sum + int(n % 10)  
    ...:         n = int(n/10)  
    ...:     if sum <20: 
    ...:         return True 
    ...:     return False 
    ...:                                                                        
In [23]: mylist=np.array([120,3,10,33,5,54,2,23,599,801])                       

您的filter解决方案:

In [51]: list(filter(getSum, mylist))                                           
Out[51]: [120, 3, 10, 33, 5, 54, 2, 23, 801]

和示例计时:

In [52]: timeit list(filter(getSum, mylist))                                    
32.8 µs ± 185 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

由于这会返回列表并进行迭代,因此如果mylist是列表而不是数组,则应该更快:

In [53]: %%timeit alist=mylist.tolist() 
    ...: list(filter(getSum, alist))                                                                        
18.4 µs ± 378 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

替代品

您建议使用 np.vectorize

In [56]: f = np.vectorize(getSum); mylist[f(mylist)]                            
Out[56]: array([120,   3,  10,  33,   5,  54,   2,  23, 801])
In [57]: timeit f = np.vectorize(getSum); mylist[f(mylist)]                     
63.4 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [58]: timeit mylist[f(mylist)]                                               
57.6 µs ± 920 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

糟糕!即使我们从时序循环中删除f创建,这也相当慢。 vectorize漂亮,但不能保证速度。

我发现 frompyfunc np.vectorize快(尽管它们是相关的):

In [59]: g = np.frompyfunc(getSum, 1,1)                                         
In [60]: g(mylist)                                                              
Out[60]: 
array([True, True, True, True, True, True, True, True, False, True],
      dtype=object)

结果是对象dtype,在这种情况下,必须将其转换为bool:

In [63]: timeit mylist[g(mylist).astype(bool)]                                  
25.5 µs ± 233 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

那比您的filter更好-但仅适用于数组而不是列表。

@Saandeep提出了列表理解

In [65]: timeit mylist[[getSum(i) for i in mylist]]                             
40.7 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

那比您的filter慢一点。

使用列表理解的一种更快的方法是:

 [i for i in mylist if getSum(i)]

这一次与您的filter相同-对于阵列和列表版本(我失去了进行计时的会话)。

纯numpy

@lante提出了一种纯粹的numpy解决方案,虽然巧妙但有点晦涩。我还没有弄清楚逻辑:

def lante(mylist):
    max_digits = np.ceil(np.max(np.log10(mylist)))  # max number of digits in mylist
    digits = mylist//(10**np.arange(max_digits)[:, None])%10  # matrix of digits
    digitsum = np.sum(digits, axis=0)  # array of sums
    mask = digitsum > 20
    return mask

不幸的是不是速度恶魔:

In [69]: timeit mylist[~lante(mylist)]                                          
58.9 µs ± 757 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

我没有安装numba,所以无法安排@jezrael's解决方案。

因此,原始的filter是一个很好的解决方案,尤其是如果您从列表而不是数组开始。尤其是考虑转换时间时,好的Python列表解决方案通常比numpy更好。

在一个大例子中,时间可能会有所不同,但是我不希望有任何不适。

答案 4 :(得分:0)

vec=np.vectorize(getSum)
mylist[vec(mylist)]
out[]:
array([120,   3,  10,  33,   5,  54,   2,  23, 801])