我正在使用python创建一个应用程序来计算胶带的重叠(对分配器进行建模会在旋转的滚筒上应用产品)。
我有一个可以正常运行的程序,但是速度很慢。我正在寻找一种优化用于填充numpy数组的for
循环的解决方案。有人可以帮我向量化下面的代码吗?
import numpy as np
import matplotlib.pyplot as plt
# Some parameters
width = 264
bbddiam = 940
accuracy = 4 #2 points per pixel
drum = np.zeros(accuracy**2 * width * bbddiam).reshape((bbddiam * accuracy , width * accuracy))
# The "slow" function
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
"""Masks a half of the array"""
to_return = np.zeros(drum.shape)
for index, v in np.ndenumerate(to_return):
if upper == True:
if index[0] * coef + intercept > index[1]:
to_return[index] = 1
else:
if index[0] * coef + intercept <= index[1]:
to_return[index] = 1
return to_return
def get_band(drum, coef, intercept, bandwidth):
"""Calculate a ribbon path on the drum"""
to_return = np.zeros(drum.shape)
t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
to_return = t1 + t2
return np.where(to_return == 2, 1, 0)
single_band = get_band(drum, 1 / 10, 130, bandwidth=15)
# Visualize the result !
plt.imshow(single_band)
plt.show()
Numba确实为我的代码带来了奇迹,将运行时间从5.8s缩短到86ms(特别感谢@ Maarten-vd-Sande):
from numba import jit
@jit(nopython=True, parallel=True)
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
...
仍然欢迎使用numpy更好的解决方案;-)
答案 0 :(得分:8)
这里根本不需要任何循环。实际上,您有两个不同的line_mask
函数。两者都不需要显式循环,但是仅通过在for
和if
中使用一对else
循环而不是{{1} }和if
在else
循环中,它被评估了很多次。
真正的numpythonic要做的事情是正确地矢量化代码,以在没有任何循环的情况下对整个数组进行操作。这是for
的向量化版本:
line_mask
将def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
"""Masks a half of the array"""
r = np.arange(drum.shape[0]).reshape(-1, 1)
c = np.arange(drum.shape[1]).reshape(1, -1)
comp = c.__lt__ if upper else c.__ge__
return comp(r * coef + intercept)
和r
的形状设置为c
和(m, 1)
,以使结果为(n, 1)
称为broadcasting,是numpy中向量化的主要内容。
更新后的(m, n)
的结果是布尔掩码(顾名思义),而不是浮点数组。这使其更小,并希望完全绕过浮动操作。现在,您可以重写line_mask
以使用遮罩而不是添加:
get_band
程序的其余部分应保持不变,因为这些功能保留了所有接口。
如果需要,可以用三行(仍然有些清晰)重写大部分程序:
def get_band(drum, coef, intercept, bandwidth):
"""Calculate a ribbon path on the drum"""
t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
return t1 & t2