如何限制numpy浮点计算精度

时间:2021-04-12 15:10:25

标签: python numpy optimization

我正在使用 numpy 来计算相机图像,这将由无符号整数灰度值表示。 我想限制浮点精度,以加快计算速度。 举个例子,假设我正在计算由高斯光束的强度分布形成的图像:

import numpy as np
import matplotlib.pyplot as plt

nx = 1000
ny = 1000
px = 5e-3

x = np.linspace(0, nx * px)
y = np.linspace(0, ny * px)

X, Y = np.meshgrid(x, y)

xc = x[-1] / 2
yc = y[-1] / 2
sigma = 1

gauss_profile = np.exp(-(np.square(X - xc) + np.square(Y - yc)) / sigma**2)
print(gauss_profile.dtype)

bitdepth = 12
gauss_profile *= 2**bitdepth - 1
camera_image = gauss_profile.astype(np.uint16)

#%% plot image
fig = plt.figure()
ax = fig.add_subplot(111)
grey_cmap = plt.get_cmap('gray')
im = ax.imshow(camera_image,
               cmap=grey_cmap,
               extent=(0, nx * px,
                       0, ny * px))
plt.xlabel('x (mm)')
plt.ylabel('y (mm)')
plt.colorbar(im)

有什么办法可以不以float64精度计算gauss_profile,而是使用足以获得所需灰度值的最小分辨率? 到目前为止,我之前尝试初始化数组并将其传递给 np.exp 调用中的 out 关键字,但这会导致 TypeError 或 ValueError 取决于 dtype。有没有其他方法可以加速这个计算?

1 个答案:

答案 0 :(得分:0)

也许尝试使用 dtype np.halfnp.single 设置您的 ndarray xy。这将限制浮点精度。但我不知道计算是否仍在进行。

<块引用>

np.half:半精度浮点数类型。 (float16)

<块引用>

np.single:单精度浮点数类型,兼容C float.(float32)

x = np.linspace(0, nx * px, dtype=np.half)
y = np.linspace(0, ny * px, dtype=np.half)

...