二维线性插值:数据和插值点

时间:2017-07-05 09:51:39

标签: python numpy matplotlib linear-interpolation mplot3d

考虑这个y(x)函数:

enter image description here

我们可以在文件中生成这些分散的点:dataset_1D.dat

# x   y   
0   0
1   1
2   0
3   -9
4   -32

以下是这些点的1D插值代码:

  1. 加载此分散点

  2. 创建x_mesh

  3. 执行1D插值

  4. 代码:

    import numpy as np
    from scipy.interpolate import interp2d, interp1d, interpnd
    import matplotlib.pyplot as plt
    
    
    # Load the data:    
    x, y  = np.loadtxt('./dataset_1D.dat', skiprows = 1).T
    
    # Create the function Y_inter for interpolation:
    Y_inter = interp1d(x,y)
    
    # Create the x_mesh:    
    x_mesh = np.linspace(0, 4, num=10)
    print x_mesh
    
    # We calculate the y-interpolated of this x_mesh :   
    Y_interpolated = Y_inter(x_mesh)
    print Y_interpolated
    
    # plot:
    
    plt.plot(x_mesh, Y_interpolated, "k+")
    plt.plot(x, y, 'ro')
    plt.legend(['Linear 1D interpolation', 'data'], loc='lower left',  prop={'size':12})
    plt.xlim(-0.1, 4.2)
    plt.grid()
    plt.ylabel('y')
    plt.xlabel('x')
    plt.show()
    

    这包括以下内容:

    enter image description here

    现在,考虑这个z(x,y)函数:

    enter image description here

    我们可以在文件中生成这些分散的点:dataset_2D.dat

    # x    y    z
    0   0   0
    1   1   0
    2   2   -4
    3   3   -18
    4   4   -48
    

    在这种情况下,我们必须执行2D插值:

    import numpy as np
    from scipy.interpolate import interp1d, interp2d, interpnd
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    
    # Load the data:
    x, y, z  = np.loadtxt('./dataset_2D.dat', skiprows = 1).T
    
    # Create the function Z_inter for interpolation:
    Z_inter = interp2d(x, y, z)
    
    # Create the x_mesh and y_mesh :
    x_mesh = np.linspace(1.0, 4, num=10)
    y_mesh = np.linspace(1.0, 4, num=10)
    print x_mesh
    print y_mesh
    
    # We calculate the z-interpolated of this x_mesh and y_mesh :
    Z_interpolated = Z_inter(x_mesh, y_mesh)
    print Z_interpolated
    print type(Z_interpolated)
    print Z_interpolated.shape
    
    # plot: 
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.scatter(x, y, z, c='r', marker='o')
    plt.legend(['data'], loc='lower left',  prop={'size':12})
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    
    plt.show()
    

    这包括以下内容:

    enter image description here

    其中散点数据再次以红点显示,与2D图一致。

    1. 我不知道如何解释Z_interpolated结果:

      根据上述代码的印刷线, Z_interpolated是一个n维的numpy数组,形状(10,10)。换句话说,具有10行和10列的2D矩阵。

    2. 我希望z[i]x_mesh[i]的每个值都有一个插值的y_mesh[i]值。为什么我没有收到这个值?

      1. 我怎样才能在3D绘图中绘制插值数据(就像2D绘图中的黑色十字架一样)?

2 个答案:

答案 0 :(得分:1)

Z_interpolated的解释:您的1-D x_meshy_mesh定义了mesh on which to interpolate。因此,您的2-D插值返回z是具有与np.meshgrid(x_mesh, y_mesh)匹配的形状(len(y),len(x))的2D数组。如您所见,您的z [i,i]而不是z [i]是x_mesh[i]y_mesh[i]的预期值。它只是有更多,网格上的所有值。

显示所有插值数据的潜在图表:

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import interp2d

# Your original function
x = y = np.arange(0, 5, 0.1)
xx, yy = np.meshgrid(x, y)
zz = 2 * (xx ** 2) - (xx ** 3) - (yy ** 2)

# Your scattered points
x = y = np.arange(0, 5)
z = [0, 0, -4, -18, -48]

# Your interpolation
Z_inter = interp2d(x, y, z)
x_mesh = y_mesh = np.linspace(1.0, 4, num=10)
Z_interpolated = Z_inter(x_mesh, y_mesh)

fig = plt.figure()
ax = fig.gca(projection='3d')
# Plot your original function
ax.plot_surface(xx, yy, zz, color='b', alpha=0.5)
# Plot your initial scattered points
ax.scatter(x, y, z, color='r', marker='o')
# Plot your interpolation data
X_real_mesh, Y_real_mesh = np.meshgrid(x_mesh, y_mesh)
ax.scatter(X_real_mesh, Y_real_mesh, Z_interpolated, color='g', marker='^')
plt.show()

enter image description here

答案 1 :(得分:1)

您需要两个插值步骤。第一个在y数据之间插值。第二个在z数据之间进行插值。然后使用两个插值数组绘制x_mesh

x_mesh = np.linspace(0, 4, num=16)

yinterp = np.interp(x_mesh, x, y)
zinterp = np.interp(x_mesh, x, z)

ax.scatter(x_mesh, yinterp, zinterp, c='k', marker='s')

在下面的完整示例中,我还添加了y方向的一些变化,以使解决方案更加通用。

u = u"""# x    y    z
0   0   0
1   3   0
2   9   -4
3   16   -18
4   32   -48"""

import io
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Load the data:
x, y, z  = np.loadtxt(io.StringIO(u), skiprows = 1, unpack=True)

x_mesh = np.linspace(0, 4, num=16)

yinterp = np.interp(x_mesh, x, y)
zinterp = np.interp(x_mesh, x, z)

fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(x_mesh, yinterp, zinterp, c='k', marker='s')
ax.scatter(x, y, z, c='r', marker='o')
plt.legend(['data'], loc='lower left',  prop={'size':12})
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')

plt.show()

enter image description here

对于使用scipy.interpolate.interp1d,解决方案基本相同:

u = u"""# x    y    z
0   0   0
1   3   0
2   9   -4
3   16   -18
4   32   -48"""

import io
import numpy as np
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Load the data:
x, y, z  = np.loadtxt(io.StringIO(u), skiprows = 1, unpack=True)

x_mesh = np.linspace(0, 4, num=16)

fy = interp1d(x, y, kind='cubic')
fz = interp1d(x, z, kind='cubic')

fig = plt.figure()
ax = Axes3D(fig)
ax.scatter(x_mesh, fy(x_mesh), fz(x_mesh), c='k', marker='s')
ax.scatter(x, y, z, c='r', marker='o')
plt.legend(['data'], loc='lower left',  prop={'size':12})
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')

plt.show()