在Cython中使用小写的unicode字符串数组的最快方法

时间:2017-12-27 21:07:50

标签: python numpy unicode cython

Numpy的字符串函数都非常慢,并且性能低于纯python列表。我期待使用Cython优化所有正常的字符串函数。

例如,让我们取一个100,000个unicode字符串的numpy数组,其数据类型为unicode或object,每个字符串都为lowecase。

alist = ['JsDated', 'УКРАЇНА'] * 50000
arr_unicode = np.array(alist)
arr_object = np.array(alist, dtype='object')

%timeit np.char.lower(arr_unicode)
51.6 ms ± 1.99 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

使用列表理解同样快

%timeit [a.lower() for a in arr_unicode]
44.7 ms ± 2.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

对于对象数据类型,我们无法使用np.char。列表理解速度是3倍。

%timeit [a.lower() for a in arr_object]
16.1 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我知道如何在Cython中执行此操作的唯一方法是创建一个空对象数组,并在每次迭代时调用Python字符串方法lower

import numpy as np
cimport numpy as np
from numpy cimport ndarray

def lower(ndarray[object] arr):
    cdef int i
    cdef int n = len(arr)
    cdef ndarray[object] result = np.empty(n, dtype='object')
    for i in range(n):
        result[i] = arr[i].lower()
    return result

这会产生适度的改善

%timeit lower(arr_object)
11.3 ms ± 383 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

我尝试使用data ndarray属性直接访问内存,如下所示:

def lower_fast(ndarray[object] arr):
    cdef int n = len(arr)
    cdef int i
    cdef char* data = arr.data
    cdef int itemsize = arr.itemsize
    for i in range(n):
        # no idea here

我相信data是一个连续的内存,一个接一个地保存所有原始字节。访问这些字节非常快,似乎转换这些原始字节会使性能提高2个数量级。我找到了一个可能有效的tolower c ++函数,但我不知道如何用Cython挂钩它。

使用最快的方法更新(对unicode不起作用)

这是迄今为止我发现的另一个SO帖子中最快的方法。这通过data属性访问numpy内存视图来降低所有ascii字符的大小。我认为它会破坏其他字节在65到90之间的unicode字符。但速度非常好。

cdef int f(char *a, int itemsize, int shape):
    cdef int i
    cdef int num
    cdef int loc
    for i in range(shape * itemsize):
        num = a[i]
        print(num)
        if 65 <= num <= 90:
            a[i] +=32

def lower_fast(ndarray arr):
    cdef char *inp
    inp = arr.data
    f(inp, arr.itemsize, arr.shape[0])
    return arr

这比其他人快100倍,而且我正在寻找。

%timeit lower_fast(arr)
103 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

1 个答案:

答案 0 :(得分:1)

这比我在我的机器上对列表的理解要快一点,但是如果你想要unicode支持,这可能是最快的方法。您需要apt-get install libunistring-dev或适合您的OS /包管理器的任何内容。

在某些C文件中,例如_lower.c

#include <stdlib.h>
#include <string.h>   
#include <unistr.h>
#include <unicase.h>

void _c_tolower(uint8_t  **s, uint32_t total_len) {
    size_t lower_len, s_len;
    uint8_t *s_ptr = *s, *lowered;
    while(s_ptr - *s < total_len) {
        s_len = u8_strlen(s_ptr);
        if (s_len == 0) {
            s_ptr += 1;
            continue;
        }
        lowered = u8_tolower(s_ptr, s_len, NULL, NULL, NULL, &lower_len);
        memcpy(s_ptr, lowered, lower_len);
        free(lowered);
        s_ptr += s_len;
    }
}

然后,在lower.pxd中你做

cdef extern from "_lower.c":
    cdef void _c_tolower(unsigned char **s, unsigned int total_len)

最后,在lower.pyx

cpdef void lower(ndarray arr):
    cdef unsigned char * _arr
    _arr = <unsigned char *> arr.data
    _c_tolower(&_arr, arr.shape[0] * arr.itemsize)

在我的笔记本电脑上,我上面的列表理解为46ms,此方法为37ms(lower_fast为0.8ms),所以它可能不值得,但我想我会输入它如果你想要一个如何将这样的东西挂钩到Cython的例子。

我不知道会有一些改进点会产生很大的不同:

    我认为
  • arr.data就像方阵一样? (我不知道,我不使用numpy做任何事情),并用\x00 s填充较短字符串的末尾。我太懒了,无法弄清楚如何让u8_tolower看过0,所以我只是手动快进它们(这就是if (s_len == 0)子句正在做的事情)。我怀疑对u8_tolower的一次调用要快几千次。
  • 我正在进行大量的释放/ memcpying。如果你聪明的话,你可以避免这种情况。
  • 认为这是每个小写unicode字符最多与其大写变体一样宽的情况,所以这不应该遇到任何段错误或缓冲区覆盖或只是重叠的子串问题,但不要请相信我的话。

不是真正的答案,但希望它有助于您进一步调查!

PS你会注意到这会降低就位,因此使用方式如下:

>>> alist = ['JsDated', 'УКРАЇНА', '道德經', 'Ну И йЕшШо'] * 2
>>> arr_unicode = np.array(alist)
>>> lower_2(arr_unicode)
>>> for x in arr_unicode:
...     print x
...
jsdated
україна
道德經
ну и йешшо
jsdated
україна
道德經
ну и йешшо

>>> alist = ['JsDated', 'УКРАЇНА'] * 50000
>>> arr_unicode = np.array(alist)
>>> ct = time(); x = [a.lower() for a in arr_unicode]; time() - ct;
0.046072959899902344
>>> arr_unicode = np.array(alist)
>>> ct = time(); lower_2(arr_unicode); time() - ct
0.037489891052246094

修改

DUH,你修改C函数看起来像这样

void _c_tolower(uint8_t  **s, uint32_t total_len) {
    size_t lower_len;
    uint8_t *lowered;

    lowered = u8_tolower(*s, total_len, NULL, NULL, NULL, &lower_len);
    memcpy(*s, lowered, lower_len);
    free(lowered);
}

然后它一气呵成。看起来更危险的是,lower_len遗留下来的旧数据可能比原始字符串更短......简而言之,这段代码完全是实验性的,仅用于说明目的,不要在生产中使用它可能会突然爆发。

无论如何,这样快〜40%:

>>> alist = ['JsDated', 'УКРАЇНА'] * 50000
>>> arr_unicode = np.array(alist)
>>> ct = time(); lower_2(arr_unicode); time() - ct
0.022463043975830078