Cython循环遍历像素仍然很慢

时间:2018-12-11 02:33:39

标签: python numpy cython

我在这里在常规python代码之间没有速度差异。它说瓶颈是html文件中的最后两行代码。有什么办法解决吗?

我想做的是遍历像素并向其中添加rgb值低于210的坐标。

from PIL import Image
import numpy as np
import time
import cython
import cv2

filename = "/home/user/PycharmProjects/Testing/files/file001.png"
image = Image.open(filename)
size = width, height = image.size
image_data = np.asarray(image)

cdef list list_text = []

@cython.boundscheck(False)
cpdef get_image_data():
    cdef int y, x
    for y in range(1683):
        for x in range(1240):
            if image_data[y, x] < 210:
                list_text.append([x, y])

3 个答案:

答案 0 :(得分:2)

循环没有问题,但是将列表追加到列表非常慢。为避免这种情况,您可以为数据分配足够大的数组并在之后缩小它(或将数据复制到具有所需大小的数组中),也可以使用std:vector来实现功能。

在这个答案中,我使用Numba,因为我对高性能Cython编码并不熟悉,但是Cython实现应该很简单。 Numba的list和tuple的内部表示形式也很有限,但是我不知道Cython中是否可以使用它们。

示例

import numpy as np
import numba as nb

@nb.njit()
def get_image_data_arr(image_data):
  array_text = np.empty((image_data.shape[0]*image_data.shape[1],2),dtype=np.int64)
  ii=0
  for y in range(image_data.shape[0]):
    for x in range(image_data.shape[1]):
      if image_data[y, x] < 210:
        array_text[ii,0]=x
        array_text[ii,1]=y
        ii+=1
  return array_text[:ii,:]

@nb.njit()
def get_image_data(image_data):
  list_text = []
  for y in range(image_data.shape[0]):
    for x in range(image_data.shape[1]):
      if image_data[y, x] < 210:
         #appending lists
         list_text.append([x, y])
         #appending tuples
         #list_text.append((x, y))
  return list_text

时间

所有计时都没有编译开销(计时中不包括对函数的第一次调用)。

#Create some data
image_data=np.random.rand(1683*1240).reshape(1683,1240)*255
image_data=image_data.astype(np.uint8)


get_image_data (Pure Python)                   : 3.4s
get_image_data (naive Numba, appending lists)  : 1.1s
get_image_data (naive Numba, appending tuples) : 0.3s
get_image_data_arr:                            : 0.012s
np.argwhere(image_data<210)                    : 0.035s

答案 1 :(得分:1)

我建议如下使用Numpy的value.*函数:

argwhere()

看起来像这样:

import numpy as np

# Create a starting image
im = np.arange(0,255,16).reshape(4,4)                                                      

现在查找所有小于210的元素的坐标:

array([[  0,  16,  32,  48],
       [ 64,  80,  96, 112],
       [128, 144, 160, 176],
       [192, 208, 224, 240]])

看起来像这样:

np.argwhere(im<210)  

答案 2 :(得分:0)

好的,所以我将其修复。现在,我要弄清楚如何将那些像素坐标保存到二维数组中。因为如果我添加python样式,它会使整个过程变慢。有什么建议么?我也不希望再次返回image_data。

有趣的是,此代码比python快28000倍!我期望速度提高100倍,而不是很多。

@cython.boundscheck(False)
cpdef const unsigned char[:, :] get_image_data(const unsigned char[:, :] image_data):

cdef int x, y
cdef list list_text = []

for y in range(1683):
    for x in range(1240):
        if image_data[y, x] < 210:
            pass
return image_data