如何在NumPy中构建三线性插值的查找表?

时间:2013-08-11 10:01:38

标签: python numpy interpolation lookup linear-interpolation

以下提取是一个500行表,我正在尝试为其构建一个numpy查找函数。我的问题是这些值是非线性的。

用户输入densityvolumecontent。所以功能将是:

def capacity_lookup(density, volume, content:

例如,典型的用户条目为capacity_lookup (47, 775, 41.3)。该函数应在45和50之间以及密度700和800以及内容40和45之间进行插值。

表格摘录是:

Volume  Density        Content 
                <30 35  40  45  50>=
45.0    <=100   0.1 1.8 0.9 2.0 0.3
45.0    200     1.5 1.6 1.4 2.4 3.0
45.0    400     0.4 2.1 0.9 1.8 2.5
45.0    600     1.3 0.8 0.2 1.7 1.9
45.0    800     0.6 0.9 0.8 0.4 0.2
45.0    1000    0.3 0.8 0.5 0.3 1.0
45.0    1200    0.6 0.0 0.6 0.2 0.2
45.0    1400    0.6 0.4 0.3 0.7 0.1
45.0    >=1600  0.3 0.0 0.6 0.1 0.3
50.0    <=100   0.1 0.0 0.5 0.9 0.2
50.0    200     1.3 0.4 0.8 0.2 2.7
50.0    400     0.4 0.1 0.7 1.3 1.7
50.0    600     0.8 0.7 0.1 1.2 1.6
50.0    800     0.5 0.3 0.4 0.2 0.0
50.0    1000    0.2 0.4 0.4 0.2 0.3
50.0    1200    0.4 0.0 0.0 0.2 0.0
50.0    1400    0.0 0.3 0.1 0.5 0.1
50.0    >=1600  0.1 0.0 0.0 0.0 0.2
55.0    <=100   0.0 0.0 0.4 0.6 0.1
55.0    200     0.8 0.3 0.7 0.1 1.2
55.0    400     0.3 0.1 0.3 1.1 0.7
55.0    600     0.4 0.3 0.0 0.6 0.1
55.0    800     0.0 0.0 0.0 0.2 0.0
55.0    1000    0.2 0.1 0.2 0.1 0.3
55.0    1200    0.1 0.0 0.0 0.1 0.0
55.0    1400    0.0 0.2 0.0 0.2 0.1
55.0    >=1600  0.0 0.0 0.0 0.0 0.1

问题

如何存储500行表,以便对非线性数据进行插值并根据用户输入获得正确的值?

澄清

  1. 如果用户输入以下向量(775,47,41.3),程序应在以下四个向量之间返回一个插值:45.0, 600, 0.2, 1.745.0, 800, 0.8, 0.450.0, 600, 0.1, 1.250.0, 800, 0.4, 0.2
  2. 假设数据将作为您设计的numpy数组从数据库中提取

2 个答案:

答案 0 :(得分:2)

我找到的第一个难点是<=>=,我可以处理Density的肢体重复,并更改非常接近的虚拟值99的值和1601,这不会影响插值。

Volume  Density        Content 
                <30 35  40  45  50>=
45.0     99   0.1 1.8 0.9 2.0 0.3
45.0    100   0.1 1.8 0.9 2.0 0.3
45.0    200     1.5 1.6 1.4 2.4 3.0
45.0    400     0.4 2.1 0.9 1.8 2.5
45.0    600     1.3 0.8 0.2 1.7 1.9
45.0    800     0.6 0.9 0.8 0.4 0.2
45.0    1000    0.3 0.8 0.5 0.3 1.0
45.0    1200    0.6 0.0 0.6 0.2 0.2
45.0    1400    0.6 0.4 0.3 0.7 0.1
45.0    1600  0.3 0.0 0.6 0.1 0.3
45.0    1601  0.3 0.0 0.6 0.1 0.3
50.0     99   0.1 0.0 0.5 0.9 0.2
50.0    100   0.1 0.0 0.5 0.9 0.2
50.0    200     1.3 0.4 0.8 0.2 2.7
50.0    400     0.4 0.1 0.7 1.3 1.7
50.0    600     0.8 0.7 0.1 1.2 1.6
50.0    800     0.5 0.3 0.4 0.2 0.0
50.0    1000    0.2 0.4 0.4 0.2 0.3
50.0    1200    0.4 0.0 0.0 0.2 0.0
50.0    1400    0.0 0.3 0.1 0.5 0.1
50.0    1600  0.1 0.0 0.0 0.0 0.2
50.0    1601  0.1 0.0 0.0 0.0 0.2
55.0     99   0.0 0.0 0.4 0.6 0.1
55.0    100   0.0 0.0 0.4 0.6 0.1
55.0    200     0.8 0.3 0.7 0.1 1.2
55.0    400     0.3 0.1 0.3 1.1 0.7
55.0    600     0.4 0.3 0.0 0.6 0.1
55.0    800     0.0 0.0 0.0 0.2 0.0
55.0    1000    0.2 0.1 0.2 0.1 0.3
55.0    1200    0.1 0.0 0.0 0.1 0.0
55.0    1400    0.0 0.2 0.0 0.2 0.1
55.0    1600  0.0 0.0 0.0 0.0 0.1
55.0    1601  0.0 0.0 0.0 0.0 0.1

然后,正如@Jaime已经指出的那样,你必须找到8个顶点才能进行三线性插值。

以下算法将为您提供以下几点:

import numpy as np
def get_8_points(filename, vi, di, ci):
    a = np.loadtxt(filename, skiprows=2)
    vol = a[:,0].repeat(a.shape[1]-2).reshape(-1,)
    den = a[:,1].repeat(a.shape[1]-2).reshape(-1,)
    #FIXME maybe you have to change the next line
    con = np.tile(np.array([30., 35., 40., 45., 50.]),a.shape[0]).reshape(-1,)
    #
    val = a[:,2:].reshape(a.shape[0]*5).reshape(-1,)

    u = np.unique(vol)
    diff = np.absolute(u-vi)
    vols = u[diff.argsort()][:2]

    u = np.unique(den)
    diff = np.absolute(u-di)
    dens = u[diff.argsort()][:2]

    u = np.unique(con)
    diff = np.absolute(u-ci)
    cons = u[diff.argsort()][:2]

    check = np.in1d(vol,vols) & np.in1d(den,dens) & np.in1d(con,cons)

    points = np.vstack((vol[check], den[check], con[check], val[check]))

    return points.T

使用您的示例:

vi, di, ci = 47, 775, 41.3
points = get_8_points(filename, vi, di, ci)
#array([[  4.50e+01,   6.00e+02,   4.00e+01,   2.00e-01],
#       [  4.50e+01,   6.00e+02,   4.50e+01,   1.70e+00],
#       [  4.50e+01,   8.00e+02,   4.00e+01,   8.00e-01],
#       [  4.50e+01,   8.00e+02,   4.50e+01,   4.00e-01],
#       [  5.00e+01,   6.00e+02,   4.00e+01,   1.00e-01],
#       [  5.00e+01,   6.00e+02,   4.50e+01,   1.20e+00],
#       [  5.00e+01,   8.00e+02,   4.00e+01,   4.00e-01],
#       [  5.00e+01,   8.00e+02,   4.50e+01,   2.00e-01]])

现在您可以执行三线性插值...

答案 1 :(得分:2)

为了补充Saullo的答案,以下是如何进行三线性插值。你基本上将立方体插入一个正方形,然后将正方形插入一个段,并将该段转换为一个点。插值顺序不会改变最终结果。 Saullo的编号方案已经是正确的:基本顶点是数字0,最后一个维度增加1到顶点数,倒数第二个增加2,第一个维度增加4.所以从他的顶点返回函数,您可以执行以下操作:

coords = np.array([47, 775, 41.3])
ndim = len(coords)
# You would get this with a call to:
# vertices = get_8_points(filename, *coords)
vertices = np.array([[  4.50e+01,   6.00e+02,   4.00e+01,   2.00e-01],
                     [  4.50e+01,   6.00e+02,   4.50e+01,   1.70e+00],
                     [  4.50e+01,   8.00e+02,   4.00e+01,   8.00e-01],
                     [  4.50e+01,   8.00e+02,   4.50e+01,   4.00e-01],
                     [  5.00e+01,   6.00e+02,   4.00e+01,   1.00e-01],
                     [  5.00e+01,   6.00e+02,   4.50e+01,   1.20e+00],
                     [  5.00e+01,   8.00e+02,   4.00e+01,   4.00e-01],
                     [  5.00e+01,   8.00e+02,   4.50e+01,   2.00e-01]])

for dim in xrange(ndim):
   vtx_delta = 2**(ndim - dim - 1)
    for vtx in xrange(vtx_delta):
        vertices[vtx, -1] += ((vertices[vtx + vtx_delta, -1] -
                               vertices[vtx, -1]) *
                              (coords[dim] -
                               vertices[vtx, dim]) /
                              (vertices[vtx + vtx_delta, dim] -
                               vertices[vtx, dim]))

print vertices[0, -1] # prints 0.55075

该函数重用顶点数组,用于导致最终值的中间插值,存储在vertices[0, -1]中,如果之后需要它,则必须复制vertices数组。< / p>