加快numpy条件替换

时间:2018-02-08 17:29:51

标签: python numpy opencv

我正在努力加快内心循环:

import cv2
import datetime
gt = cv2.imread("image.png",0)

start = datetime.datetime.now()
for i in range(gt.shape[0]):
    for j in range(gt.shape[1]):
        if(gt[i,j] == 0): continue
        limit = min(60,((gt.shape[1]-j)))
        for d in range(limit): # <------- this one
            if(gt[i,j] < gt[i,j+d] - d):
                gt[i,j] = 0
                break
print(datetime.datetime.now() - start)

有没有办法使用numpy内置运算符来重写它?目前它非常慢,每张图像都是46秒。我已经尝试过类似的东西:

gt[ gt[i,j+range(d)] - range(d)]=0

但当然它不起作用,因为你不能用列表加总int。

2 个答案:

答案 0 :(得分:3)

ds可以用常量np.arange(60)数组表示,gt[i,j+d]gt[i, j:j+d]表示。

import numpy as np
ds = np.arange(60)
for i in range(gt.shape[0]):
    for j in range(gt.shape[1]):
        if(gt[i, j] == 0):
            continue
        limit = min(60, ((gt.shape[1]-j)))
        if np.any(gt[i, j] < (gt[i, j:j+limit]-ds[:limit])):
            gt[i, j] = 0

如果您没有内存问题,可以进一步对其进行矢量化。例如:

mask = np.zeros(gt.shape, dtype="bool")
for dist in range(1, 60):
    diffs = gt[:, :-dist] < (gt[:, dist:] - dist)
    mask[:, :-dist] |= diffs
gt[mask] = 0

使用无符号整数时,diffs应使用以下函数计算:

diffs = (gt2[:, :-dist] < (gt2[:, dist:] - dist)) & (gt2[:, dist:]>dist)

防止溢出问题。

给出约。在200x200pix图像上加速200倍。

答案 1 :(得分:0)

gt = np.array([[1, 1, 1], [1.5, 2.5, 2.5], [2, 3.5, 4]])
limit = 2

m = np.zeros_like(gt, dtype=bool)
for d in range(1, limit + 1):
    if d >= gt.shape[0]:
        break
    shifted = np.roll(gt, shift=-d, axis=0)
    shifted[-d:] = np.nan
    m = m | (gt < shifted - d)
gt[m] = 0

print(gt)

输出

array([[1. , 0. , 0. ],
       [1.5, 2.5, 0. ],
       [2. , 3.5, 4. ]])