numba.vectorize - 不支持的数组dtype

时间:2018-01-13 19:08:11

标签: python pandas numba

我是[5, 26]的新手,似乎无法确定传递给numba的参数。这就是我要做的事情:

vectorize

任务是根据以下逻辑创建一个新列,首先基于标准test = [x for x in range(10)] test2 = ['a', 'a', 'a', 'b', 'b', 'c', 'c', 'c', 'c', 'c'] test_df = pd.DataFrame({'test': test, 'test2': test2}) test_df['test3'] = np.where(test_df['test'].values % 2 == 0, test_df['test'].values, np.nan) test test2 test3 test4 0 0 a 0.0 0.0 1 1 a NaN NaN 2 2 a 2.0 4.0 3 3 b NaN NaN 4 4 b 4.0 16.0 5 5 c NaN NaN 6 6 c 6.0 36.0 7 7 c NaN NaN 8 8 c 8.0 64.0 9 9 c NaN NaN

pandas

使用def nonnumba_test(row): if row['test2'] == 'a': return row['test'] * row['test3'] else: return np.nan ;我了解到,使用applynp.where对象的.values属性可以更快地完成此任务,但希望针对Series对此进行测试。

numba

接下来,当我尝试使用test_df.apply(nonnumba_test, axis=1) 0 0.0 1 NaN 2 4.0 3 NaN 4 NaN 5 NaN 6 NaN 7 NaN 8 NaN 9 NaN dtype: float64 装饰器

numba.vectorize

我收到以下错误

@numba.vectorize()
def numba_test(x, y, z):
    if x == 'a':
        return y * z
    else:
        return np.nan

我想我需要在numba_test(test_df['test2'].values, test_df['test'].values, test_df['test3'].values) ValueError: Unsupported array dtype: object 参数中指定返回类型,但我似乎无法弄明白。

1 个答案:

答案 0 :(得分:5)

问题numba不容易支持字符串(see heresee here)。

解决方案是处理numba修饰函数之外的布尔逻辑if x=='a'。修改您的示例(numba_test和输入参数)如下所示会产生所需的输出(示例中最后两个块之上的所有内容都保持不变):

from numba import vectorize, float64, int64, boolean

#@vectorize() will also work here, but I think it's best practice with numba to specify types.
@vectorize([float64(boolean, int64, float64)])
def numba_test(x, y, z):
    if x:
        return y * z
    else:
        return np.nan

# now test it...
# NOTICE the boolean argument, **not** string!
numba_test(test_df['test2'].values =='a', 
           test_df['test'].values, 
           test_df['test3'].values)  

返回:

array([  0.,  nan,   4.,  nan,  nan,  nan,  nan,  nan,  nan,  nan])

根据需要。

最后的注释:您会看到我在上面的vectorize装饰器中指定了类型。是的,这有点烦人,但我认为这是最好的做法,因为它让你头疼完全就像这样:如果你指定了类型,你将无法找到字符串类型,那会解决它。