在Cython中高效的numpy图像循环

时间:2016-06-24 09:53:42

标签: python opencv numpy cython

我正在尝试优化使用opencv和numpy查看图像中笔画宽度的Python程序。到目前为止,我已经设法通过使用Cython进行静态打字来加速4倍,但我想应该可以加快速度。

当我查看我的HTML报告时,看起来内部循环,即while循环,并未转换为c循环。但它只使用静态类型变量?

同样,数组的简单副本“outImg = img”标记为黄色,虽然它们都输入为“np.ndarray [SMALL_INT,ndim = 3]”,为什么会这样?

我的完整HTML报告位于:http://da-robotteknik.se/shared/swt.html

我的代码看起来像这样,每个图像帧调用一次swt函数:

import numpy as np
cimport numpy as np
import cv2

cimport cython

ctypedef unsigned char SMALL_INT
ctypedef np.float32_t FLOAT

cdef extern from "math.h":
    double sqrt(double m)

cdef np.ndarray[FLOAT, ndim=2] gradient(np.ndarray[SMALL_INT, ndim=2] img, char dxdy):
    cdef np.ndarray[FLOAT, ndim=2] d
    d = cv2.Sobel(img,  cv2.CV_32F, dxdy==0, dxdy==1,  ksize = 1,  delta = 0.5,  scale = 0.01)
    return d

@cython.boundscheck(False)
def swt(np.ndarray[SMALL_INT, ndim=3] img, np.ndarray[SMALL_INT, ndim=2] mask, int thStart, int thStop, int pDivider, int lMin, int lMax):
    cdef np.ndarray[SMALL_INT, ndim=3] outImg
    cdef np.ndarray[FLOAT, ndim=2] dRow, dCol, gradMag
    cdef np.ndarray[SMALL_INT, ndim=2] imgGray, dRowUint, dRowPosMask, startSWT, stopSWT, gradMagUint
    cdef int h,w,ch, points, row, col, i, l, r, c
    cdef float dirRow, dirCol, dirMag
    cdef np.ndarray[long, ndim=1] startPointsRow, startPointsCol        

    t1 = cv2.getTickCount();    

    outImg = img
    h,  w,  ch = np.shape(img)

    # Calculate gradients
    imgGray = cv2.cvtColor(img,  cv2.COLOR_BGR2GRAY)
    dCol = gradient(imgGray, 0)
    dRow = gradient(imgGray, 1)

    # Make a mask for pos row-gradient (This is startpoint in the SWT)
    dRowUint = cv2.convertScaleAbs(dRow,  alpha=255)
    ret,  dRowPosMask = cv2.threshold(dRowUint, 128, 255, cv2.THRESH_BINARY)

    # Calculate magnitude of gradients
    gradMag = cv2.absdiff(dRow, 0.5) + cv2.absdiff(dCol, 0.5)
    gradMagUint = cv2.convertScaleAbs(gradMag,  alpha=255)

    # Find suitable startpoints for the SWT
    startSWT = np.bitwise_and(gradMagUint, dRowPosMask)
    startSWT = np.bitwise_and(startSWT,  mask)
    ret,  startSWT = cv2.threshold(startSWT,  thStart,  255,  cv2.THRESH_TOZERO)
    startPointsRow,  startPointsCol = np.nonzero(startSWT)

    # Find stop points
    stopSWT = np.bitwise_and(gradMagUint, np.invert(dRowPosMask))
    stopSWT = np.bitwise_and(stopSWT,  mask)

    pointsTuple = np.shape(startPointsRow)
    points = pointsTuple[0];

    t2 = cv2.getTickCount();

    # Step until stopSWT > th
    for i in range(points):
        # Find start pos
        row = startPointsRow[i]
        col = startPointsCol[i]
        # Find direction
        dirRow = -dRow[row][col]+0.5
        dirCol = -dCol[row][col]+0.5
        dirMag = sqrt(dirRow*dirRow + dirCol*dirCol)
        dirRow /= dirMag
        dirCol /= dirMag
        # Step until stop found or l > 100
        l = 1
        r = <int>(row + l*dirRow)
        c = <int>(col + l*dirCol)
        while(r<h and r>=0 and c<w and c>=0 and stopSWT[r][c] < thStop and l < lMax):
            l  += 1;
            r = <int>(row + l*dirRow)
            c = <int>(col + l*dirCol)
        if (r<h and r>=0 and c<w and c>=0 and l < lMax and l > lMin):
            cv2.line(outImg,  (col, row),  (c, r), (0, 255, 0))
        elif (l > lMax):
            cv2.line(outImg,  (col, row),  (c, r), (255, 0, 0))
        elif (l < lMin):
            cv2.line(outImg,  (col, row),  (c, r), (0, 0, 255))

    t3 = cv2.getTickCount();

    print((t2-t1)/cv2.getTickFrequency())
    print((t3-t2)/cv2.getTickFrequency())        

    return outImg

0 个答案:

没有答案