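interpolate.griddata uses only one core

Time: 2018-09-07 18:04:02

Tags: python numpy parallel-processing scipy interpolation

I am interpolating a 2D numpy array to fill in missing values that are marked with NaN. The following code works, but it uses only one core. Is there a better function I can use to take advantage of all 24 cores I have?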

import numpy as np
from scipy import interpolate

# `array` is the 2D input, with NaN marking the missing values
x = np.arange(0, array.shape[1])
y = np.arange(0, array.shape[0])
# mask invalid values
array = np.ma.masked_invalid(array)
xx, yy = np.meshgrid(x, y)
# get only the valid values
x1 = xx[~array.mask]
y1 = yy[~array.mask]
newarr = array[~array.mask]

GD1 = interpolate.griddata((x1, y1), newarr.ravel(),
                           (xx, yy),
                           method='cubic')

1 Answer:

Answer 0: (score: 0)

I think you can do this with dask. I am not very familiar with dask, but here is a start:

import numpy as np
from scipy import interpolate
import dask.array as da
import matplotlib.pyplot as plt
from dask import delayed

# create data with random missing entries
ar_size = 2000
chunk_size = 500
z_array = np.ones((ar_size, ar_size))
z_array[np.random.randint(0, ar_size-1, 50),
      np.random.randint(0, ar_size-1, 50)]= np.nan

# XY coords
x = np.linspace(0, 3, z_array.shape[1])
y = np.linspace(0, 3, z_array.shape[0])

# gen sin wave for testing
z_array = z_array * np.sin(x)
# prove there are nans in the dataset
assert np.isnan(np.sum(z_array))

xx, yy = np.meshgrid(x, y)
print("global x.size: ", xx.size)

# make dask arrays
# stack the three 2D arrays into one explicit 3D array before wrapping it
dask_xyz = da.from_array(np.stack((xx, yy, z_array)), chunks=(3, chunk_size, "auto"), name="dask_all")
dask_xx = dask_xyz[0,:,:]
dask_yy = dask_xyz[1,:,:]
dask_zz = dask_xyz[2,:,:]

# select only valid values
dask_valid_y1 = dask_yy[~da.isnan(dask_zz)]
dask_valid_x1 = dask_xx[~da.isnan(dask_zz)]
dask_newarr = dask_zz[~da.isnan(dask_zz)]

def gd_wrapped(x1, y1, newarr, xx, yy):
    # note: linear and cubic griddata impl do not extrapolate
    # and therefore fail near the boundaries... see RBF interp instead
    print("local x.size: ", x1.size)
    gd_zz = interpolate.griddata((x1, y1), newarr.ravel(),
                               (xx, yy),
                               method='nearest')
    return gd_zz

def rbf_wrapped(x1, y1, newarr, xx, yy):
    rbf_interpolant = interpolate.Rbf(x1, y1, newarr, function='linear')
    return rbf_interpolant(xx, yy)

# interpolate
# gd_chunked = [delayed(rbf_wrapped)(x1, y1, newarr, xx, yy) for \
gd_chunked = [delayed(gd_wrapped)(x1, y1, newarr, xx, yy) for \
            x1, y1, newarr, xx, yy \
            in \
            zip(dask_valid_x1.to_delayed().flatten(),
                dask_valid_y1.to_delayed().flatten(),
                dask_newarr.to_delayed().flatten(),
                dask_xx.to_delayed().flatten(),
                dask_yy.to_delayed().flatten())]
# each delayed chunk evaluates to a plain numpy array, so join them with numpy
gd_out = delayed(np.concatenate)(gd_chunked, axis=0)
gd_out.visualize("dask_par.png")
gd1 = np.array(gd_out.compute())
print(gd1)
assert gd1.shape == (ar_size, ar_size)
print(gd1.shape)
plt.figure()
plt.imshow(gd1)
plt.savefig("dask_par_sin.png")

# prove we have no more nans in the data
assert not np.isnan(gd1).any()

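This implementation has some problems. griddata does not extrapolate, so NaNs at the chunk boundaries are an issue. You can probably solve this with some overlapping cells. As a stopgap, you can use method='nearest' or try radial basis function interpolation.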
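The overlapping-cells idea can be sketched without dask as follows. This is only a minimal illustration, not a tested replacement for the code above: the names `fill_nans_in_strips` and `_fill_strip`, the `overlap` and `n_strips` parameters, and the `map_fn` hook are all made up for this example. It splits the array into horizontal strips, pads each strip with a few halo rows from its neighbours, interpolates each padded strip on its own, and keeps only the core rows. Cells that cubic interpolation still leaves as NaN (outside the convex hull of the valid points) fall back to method='nearest'.

```python
import numpy as np
from scipy import interpolate


def _fill_strip(task):
    """Interpolate one padded strip and return only its core rows."""
    block, core_start, core_stop, method = task
    rows, cols = np.indices(block.shape)
    valid = ~np.isnan(block)
    points = (cols[valid], rows[valid])
    values = block[valid]
    filled = interpolate.griddata(points, values, (cols, rows), method=method)
    # linear/cubic griddata returns NaN outside the convex hull of the
    # valid points, so fill whatever is left with the nearest valid value
    holes = np.isnan(filled)
    if holes.any():
        filled[holes] = interpolate.griddata(
            points, values, (cols[holes], rows[holes]), method="nearest")
    # drop the halo rows that were only there to support the interpolation
    return filled[core_start:core_stop]


def fill_nans_in_strips(arr, n_strips=4, overlap=8, method="cubic", map_fn=map):
    """Fill NaNs strip by strip; `map_fn` decides how the strips are run.

    Hypothetical helper: with the default built-in `map` the strips run
    serially; an executor's `.map` can be passed in to run them in parallel.
    """
    n_rows = arr.shape[0]
    edges = np.linspace(0, n_rows, n_strips + 1).astype(int)
    tasks = []
    for start, stop in zip(edges[:-1], edges[1:]):
        lo = max(start - overlap, 0)       # halo above the strip
        hi = min(stop + overlap, n_rows)   # halo below the strip
        tasks.append((arr[lo:hi], start - lo, stop - lo, method))
    return np.concatenate(list(map_fn(_fill_strip, tasks)), axis=0)
```

To spread the strips over several cores, one option is to pass the `map` method of a `concurrent.futures.ProcessPoolExecutor` as `map_fn`, called from under an `if __name__ == "__main__":` guard; `_fill_strip` is a top-level function so it can be pickled. Whether this actually beats a single griddata call depends on the array size and on how much halo is needed, and results near strip boundaries can differ slightly from a global interpolation, so it is worth comparing both on real data.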