将numpy数组与标量与'out = ...'进行比较

时间:2018-03-17 11:35:08

标签: python performance numpy

我有一个字符串数组:

s = np.array(['a', 'b', 'c'])

我希望有一个函数array_equal_to_scalar来将s与字符串'a'进行比较,并将输出写入预分配数组(我需要快速性能):

mask = np.empty(s.shape)
np.array_equal_to_scalar(s, 'a', out=mask)

所以,我希望mask将是

> [True False False]

有没有办法制作类似array_equal_to_scalar的内容?

1 个答案:

答案 0 :(得分:3)

您正在寻找的是numpy.equal ufunc,它似乎不适用于您的用例。

为了以你想要的方式使用它,我们需要明确地将要比较的标量广播为适当形状的numpy数组:

import numpy as np

a = np.array(['a','b','c'])
res = np.empty(a.shape, dtype=bool)
np.equal(a, np.broadcast_to(['a'], a.shape), out=res)

不幸的是,上面的调用(1)忽略广播并给出一个恒定的结果,(2)是NotImplemented。我们可以尝试分配一个适当的比较数组来强制进行适当的元素比较,但无济于事:

>>> compare = np.full(a.shape, 'a')
>>> np.equal(a, compare)
NotImplemented

似乎通过numpy ufunc s的高效实现仅针对数字类型给出(我还没有时间查看源代码)。但我不希望更高级别的函数能够直接使用预分配的输入数组作为缓冲区。使用已编译的ufunc,我可以想象out关键字参数允许您绕过临时数组的创建,但我不认为这里有另一种选择。