我是否再次低估了NumPy的力量?

时间:2013-12-05 02:25:42

标签: python arrays numpy

我认为我不能再优化我的功能,但这不是我第一次低估NumPy的力量。

鉴于:

  • 2级NumPy数组,坐标
  • 1个具有每个坐标高程的NumPy数组
  • 带有电台的Pandas DataFrame

功能:

def Function(xy_coord):
    # Apply a KDTree search for (and select) 8 nearest stations
    dist_tree_real, ix_tree_real = tree.query(xy_coord, k=8, eps=0, p=1)
    df_sel = df.ix[ix_tree_real]        

    # Fits multi-linear regression to find coefficients
    M = np.vstack((np.ones(len(df_sel['POINT_X'])),df_sel['POINT_X'], df_sel['POINT_Y'],df_sel['Elev'])).T
    b1,b2,b3 = np.linalg.lstsq(M,df_sel['TEMP'])[0][1:4]

    # Compute IDW using the coefficients
    return sum( (1/dist_tree_real)**2)**-1 * sum((df_sel['TEMP'] + (b1*(xy_coord[0] - df_sel['POINT_X'])) + 
                                                  (b2*(xy_coord[1]-df_sel['POINT_Y'])) + (b3*(dem[index]-df_sel['Elev']))) * 
                                             (1/dist_tree_real)**2)

我在坐标上应用函数如下:

for index, coord in enumerate(xy):    
    outarr[index] = func(coord)

这是一个迭代过程,如果我尝试这个outarr = np.vectorize(func)(xy)然后Python崩溃了,所以我想这是我应该避免做的事情。

我还准备了一个IPython笔记本,所以我可以编写LaTeX,这是我一直梦寐以求的事情。直到现在。这一天到来了。 Yeah


偏离主题:数学不会出现在nbviewer ..在我的本地机器上它看起来像这样:

LaTeX in my local IPython working

1 个答案:

答案 0 :(得分:1)

我的建议是不要使用DataFrame进行计算,只使用numpy数组。这是代码:

dist, idx = tree.query(xy, k=8, eps=0, p=1)
columns = ["POINT_X", "POINT_Y", "Elev", "TEMP"]
px, py, elev, tmp = df[columns].values.T[:, idx, None]
tmp = np.squeeze(tmp)
one = np.ones_like(px)

m = np.concatenate((one, px, py, elev), axis=-1)
mtm = np.einsum("ijx,ijy->ixy", m, m)
mty = np.einsum("ijx,ij->ix", m, tmp)
b1,b2,b3 = np.linalg.solve(mtm, mty)[:, 1:].T

px, py, elev = px.squeeze(), py.squeeze(), elev.squeeze()

b1 = b1[:,None]
b2 = b2[:,None]
b3 = b3[:,None]

rdist = (1/dist)**2
t0 = tmp + b1*(xy[:,0,None]-px) + b2*(xy[:,1,None]-py) + b3*(dem[:,None]-elev)
outarr = (t0*rdist).sum(1) / rdist.sum(1)

print outarr

输出:

[ -499.24287422  -540.28111668  -512.43789349  -589.75389439  -411.65598912
  -233.1779803  -1249.63803291  -232.4924416   -273.3978919   -289.35240473]

代码中有一些技巧:

    在numpy 1.8中的
  1. np.linalg.solve是一个通用的ufunc,它可以通过一次调用求解许多线性方程,但lstsq不是。所以我需要使用solve来计算lstsq

  2. 要通过一次调用进行多次矩阵乘法运算,我们无法使用doteinsum()可以解决问题,但我认为它可能比dot慢。您可以timeit获取真实数据。