3D原点x的3D阵列快速插值

时间:2014-01-18 02:09:36

标签: python arrays numpy 3d

此问题类似于之前回答Fast interpolation over 3D array的问题,但无法解决我的问题。

我有一个尺寸为(时间,海拔,纬度,经度)的四维数组,标记为y.shape=(nt, nalt, nlat, nlon)。 x是海拔高度并随(时间,纬度,经度)变化,这意味着x.shape = (nt, nalt, nlat, nlon)。我想在每个(nt,nlat,nlon)的高度插值。插值的x_new应该是1d,而不是随着(时间,纬度,经度)改变。

我使用numpy.interp,与scipy.interpolate.interp1d相同,并考虑前一篇文章中的答案。我不能用那些答案来减少循环。

我只能这样做:

# y is a 4D ndarray
# x is a 4D ndarray
# new_y is a 4D array
for i in range(nlon):
    for j in range(nlat):
        for k in range(nt):
            y_new[k,:,j,i] = np.interp(new_x, x[k,:,j,i], y[k,:,j,i])

这些循环使得这种插值计算速度太慢。有人会有好主意吗?帮助将受到高度赞赏。

1 个答案:

答案 0 :(得分:1)

这是我使用numba的解决方案,它的速度提高了约3倍。

首先创建测试数据,x需要按升序排列:

import numpy as np
rows = 200000
cols = 66
new_cols = 69
x = np.random.rand(rows, cols)
x.sort(axis=-1)
y = np.random.rand(rows, cols)
nx = np.random.rand(new_cols)
nx.sort() 

在numpy中进行200000次interp:

%%time
ny = np.empty((x.shape[0], len(nx)))
for i in range(len(x)):
    ny[i] = np.interp(nx, x[i], y[i])

我使用合并方法而不是二元搜索方法,因为nx是有序的,nx的长度与x大致相同。

  • interp()使用二进制搜索,时间复杂度为O(len(nx)*log2(len(x))
  • 合并方法:时间复杂度为O(len(nx) + len(x))

这是numba代码:

import numba

@numba.jit("f8[::1](f8[::1], f8[::1], f8[::1], f8[::1])")
def interp2(x, xp, fp, f):
    n = len(x)
    n2 = len(xp)
    j = 0
    i = 0
    while x[i] <= xp[0]:
        f[i] = fp[0]
        i += 1

    slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])        
    while i < n:
        if x[i] >= xp[j] and x[i] < xp[j+1]:
            f[i] = slope*(x[i] - xp[j]) + fp[j]
            i += 1
            continue
        j += 1
        if j + 1 == n2:
            break
        slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])   

    while i < n:
        f[i] = fp[n2-1]
        i += 1

@numba.jit("f8[:, ::1](f8[::1], f8[:, ::1], f8[:, ::1])")
def multi_interp(x, xp, fp):
    nrows = xp.shape[0]
    f = np.empty((nrows, x.shape[0]))
    for i in range(nrows):
        interp2(x, xp[i, :], fp[i, :], f[i, :])
    return f

然后调用numba函数:

%%time
ny2 = multi_interp(nx, x, y)

检查结果:

np.allclose(ny, ny2)

在我的电脑上,时间是:

python version: 3.41 s
numba version: 1.04 s

此方法需要一个数组,最后一个轴是interp()的轴。