对numpy中的for循环进行矢量化处理,以计算管道胶带重叠

时间:2020-01-03 21:35:11

标签: python numpy for-loop vectorization numba

我正在使用python创建一个应用程序来计算胶带的重叠(对分配器进行建模会在旋转的滚筒上应用产品)。

我有一个可以正常运行的程序,但是速度很慢。我正在寻找一种优化用于填充numpy数组的for循环的解决方案。有人可以帮我向量化下面的代码吗?

import numpy as np
import matplotlib.pyplot as plt

# Some parameters
width = 264
bbddiam = 940
accuracy = 4 #2 points per pixel

drum = np.zeros(accuracy**2 * width * bbddiam).reshape((bbddiam * accuracy , width * accuracy))

# The "slow" function
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
    """Masks a half of the array"""
    to_return = np.zeros(drum.shape)
    for index, v in np.ndenumerate(to_return):
        if upper == True:
            if index[0] * coef + intercept > index[1]:
                to_return[index] = 1
        else:
            if index[0] * coef + intercept <= index[1]:
                to_return[index] = 1
    return to_return


def get_band(drum, coef, intercept, bandwidth):
    """Calculate a ribbon path on the drum"""
    to_return = np.zeros(drum.shape)
    t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
    t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
    to_return = t1 + t2
    return np.where(to_return == 2, 1, 0)

single_band = get_band(drum, 1 / 10, 130, bandwidth=15)

# Visualize the result !
plt.imshow(single_band)
plt.show()

Numba确实为我的代码带来了奇迹,将运行时间从5.8s缩短到86ms(特别感谢@ Maarten-vd-Sande):

from numba import jit
@jit(nopython=True, parallel=True)
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
    ...

仍然欢迎使用numpy更好的解决方案;-)

1 个答案:

答案 0 :(得分:8)

这里根本不需要任何循环。实际上,您有两个不同的line_mask函数。两者都不需要显式循环,但是仅通过在forif中使用一对else循环而不是{{1} }和ifelse循环中,它被评估了很多次。

真正的numpythonic要做的事情是正确地矢量化代码,以在没有任何循环的情况下对整个数组进行操作。这是for的向量化版本:

line_mask

def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy): """Masks a half of the array""" r = np.arange(drum.shape[0]).reshape(-1, 1) c = np.arange(drum.shape[1]).reshape(1, -1) comp = c.__lt__ if upper else c.__ge__ return comp(r * coef + intercept) r的形状设置为c(m, 1),以使结果为(n, 1)称为broadcasting,是numpy中向量化的主要内容。

更新后的(m, n)的结果是布尔掩码(顾名思义),而不是浮点数组。这使其更小,并希望完全绕过浮动操作。现在,您可以重写line_mask以使用遮罩而不是添加:

get_band

程序的其余部分应保持不变,因为这些功能保留了所有接口。

如果需要,可以用三行(仍然有些清晰)重写大部分程序:

def get_band(drum, coef, intercept, bandwidth):
    """Calculate a ribbon path on the drum"""
    t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
    t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
    return t1 & t2