此问题类似于之前回答Fast interpolation over 3D array的问题,但无法解决我的问题。
我有一个尺寸为(时间,海拔,纬度,经度)的四维数组,标记为y.shape=(nt, nalt, nlat, nlon)
。 x是海拔高度并随(时间,纬度,经度)变化,这意味着x.shape = (nt, nalt, nlat, nlon)
。我想在每个(nt,nlat,nlon)的高度插值。插值的x_new应该是1d,而不是随着(时间,纬度,经度)改变。
我使用numpy.interp
,与scipy.interpolate.interp1d
相同,并考虑前一篇文章中的答案。我不能用那些答案来减少循环。
我只能这样做:
# y is a 4D ndarray
# x is a 4D ndarray
# new_y is a 4D array
for i in range(nlon):
for j in range(nlat):
for k in range(nt):
y_new[k,:,j,i] = np.interp(new_x, x[k,:,j,i], y[k,:,j,i])
这些循环使得这种插值计算速度太慢。有人会有好主意吗?帮助将受到高度赞赏。
答案 0 :(得分:1)
这是我使用numba的解决方案,它的速度提高了约3倍。
首先创建测试数据,x
需要按升序排列:
import numpy as np
rows = 200000
cols = 66
new_cols = 69
x = np.random.rand(rows, cols)
x.sort(axis=-1)
y = np.random.rand(rows, cols)
nx = np.random.rand(new_cols)
nx.sort()
在numpy中进行200000次interp:
%%time
ny = np.empty((x.shape[0], len(nx)))
for i in range(len(x)):
ny[i] = np.interp(nx, x[i], y[i])
我使用合并方法而不是二元搜索方法,因为nx
是有序的,nx
的长度与x
大致相同。
interp()
使用二进制搜索,时间复杂度为O(len(nx)*log2(len(x))
O(len(nx) + len(x))
这是numba代码:
import numba
@numba.jit("f8[::1](f8[::1], f8[::1], f8[::1], f8[::1])")
def interp2(x, xp, fp, f):
n = len(x)
n2 = len(xp)
j = 0
i = 0
while x[i] <= xp[0]:
f[i] = fp[0]
i += 1
slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])
while i < n:
if x[i] >= xp[j] and x[i] < xp[j+1]:
f[i] = slope*(x[i] - xp[j]) + fp[j]
i += 1
continue
j += 1
if j + 1 == n2:
break
slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])
while i < n:
f[i] = fp[n2-1]
i += 1
@numba.jit("f8[:, ::1](f8[::1], f8[:, ::1], f8[:, ::1])")
def multi_interp(x, xp, fp):
nrows = xp.shape[0]
f = np.empty((nrows, x.shape[0]))
for i in range(nrows):
interp2(x, xp[i, :], fp[i, :], f[i, :])
return f
然后调用numba函数:
%%time
ny2 = multi_interp(nx, x, y)
检查结果:
np.allclose(ny, ny2)
在我的电脑上,时间是:
python version: 3.41 s
numba version: 1.04 s
此方法需要一个数组,最后一个轴是interp()
的轴。