在Python中插值3D数组

时间:2018-08-30 09:48:54

标签: python numpy scipy interpolation

我有一个3D NumPy数组,如下所示:

arr = np.empty((4,4,5))
arr[:] = np.nan
arr[0] = 1
arr[3] = 4

arr
>>> [[[ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]]

     [[ nan nan nan nan nan]
      [ nan nan nan nan nan]
      [ nan nan nan nan nan]
      [ nan nan nan nan nan]]

     [[ nan nan nan nan nan]
      [ nan nan nan nan nan]
      [ nan nan nan nan nan]
      [ nan nan nan nan nan]]

     [[ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]]]

我想沿axis=0进行插值,以便获得以下信息:

>>> [[[ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]
      [ 1.  1.  1.  1.  1.]]

     [[ 2.  2.  2.  2.  2.]
      [ 2.  2.  2.  2.  2.]
      [ 2.  2.  2.  2.  2.]
      [ 2.  2.  2.  2.  2.]]

     [[ 3.  3.  3.  3.  3.]
      [ 3.  3.  3.  3.  3.]
      [ 3.  3.  3.  3.  3.]
      [ 3.  3.  3.  3.  3.]]

     [[ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]
      [ 4.  4.  4.  4.  4.]]]

我一直在研究SciPy模块,似乎有一些方法可以在1D和2D阵列上执行此操作,但不是我需要的3D格式-尽管我可能错过了一些东西。

2 个答案:

答案 0 :(得分:2)

使用apply_along_axis的解决方案:

import numpy as np

def pad(data):
    good = np.isfinite(data)
    interpolated = np.interp(np.arange(data.shape[0]),
                             np.flatnonzero(good), 
                             data[good])
    return interpolated


arr = np.arange(6, dtype=float).reshape((3,2))
arr[1, 1] = np.nan
print(arr)

new = np.apply_along_axis(pad, 0, arr)
print(arr)
print(new)

输出:

[[ 0.  1.]
 [ 2. nan]
 [ 4.  5.]]

[[ 0.  1.]
 [ 2. nan]
 [ 4.  5.]]

[[0. 1.]
 [2. 3.]
 [4. 5.]]

[edit]提出的第一个解决方案:

this answer中的代码进行了一些修改:

import numpy as np
from scipy import interpolate

A = np.empty((4,4,5))
A[:] = np.nan
A[0] = 1
A[3] = 4

indexes = np.arange(A.shape[0])
good = np.isfinite(A).all(axis=(1, 2)) 

f = interpolate.interp1d(indexes[good], A[good],
                         bounds_error=False,
                         axis=0)

B = f(indexes)
print(B)

给予:

[[[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]

 [[2. 2. 2. 2. 2.]
  [2. 2. 2. 2. 2.]
  [2. 2. 2. 2. 2.]
  [2. 2. 2. 2. 2.]]

 [[3. 3. 3. 3. 3.]
  [3. 3. 3. 3. 3.]
  [3. 3. 3. 3. 3.]
  [3. 3. 3. 3. 3.]]

 [[4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4.]]]

仅当NaN都在同一切片上时,它才有效。 NaN孤立的切片将被忽略。

答案 1 :(得分:0)

从xdze2和先前的answer here提供的评论中,我想到了这一点:

import numpy as np

def pad(data):
    bad_indexes = np.isnan(data)
    good_indexes = np.logical_not(bad_indexes)
    good_data = data[good_indexes]
    interpolated = np.interp(bad_indexes.nonzero()[0], good_indexes.nonzero()[0], 
    good_data)
    data[bad_indexes] = interpolated
    return data

arr = np.empty((4,4,5))
arr[:] = np.nan

arr[0] = 25
arr[3] = 32.5

# Apply the pad method to each 0 axis
new = np.apply_along_axis(pad, 0, arr)

'pad'方法本质上应用插值,而np.apply_along_axis方法确保将其应用于3D数组。