在numpy中求解大量的小方程系统

时间:2012-11-17 15:48:43

标签: python numpy linear-algebra

我有大量的小线性方程系统,我想用numpy有效地解决。基本上,在给定A[:,:,:]b[:,:]的情况下,我希望找到x[:,:]给出的A[i,:,:].dot(x[i,:]) = b[i,:]。所以如果我不关心速度,我可以解决这个问题

for i in range(n):
    x[i,:] = np.linalg.solve(A[i,:,:],b[i,:])

但由于这涉及到python中的显式循环,并且由于A通常具有类似(1000000,3,3)的形状,因此这种解决方案会非常慢。如果numpy不符合这个,我可以在fortran中进行这个循环(即使用f2py),但如果可能的话我宁愿留在python中。

4 个答案:

答案 0 :(得分:3)

我想回答自己是一个失礼,但这是我现在遇到的Fortran解决方案,即其他解决方案在速度和简洁方面都有效竞争。

function pixsolve(A, b) result(x)
    implicit none
    real*8    :: A(:,:,:), b(:,:), x(size(b,1),size(b,2))
    integer*4 :: i, n, m, piv(size(b,1)), err
    n = size(A,3); m = size(A,1)
    x = b
    do i = 1, n
        call dgesv(m, 1, A(:,:,i), m, piv, x(:,i), m, err)
    end do
end function

这将编译为:

f2py -c -m foo{,.f90} -llapack -lblas

从python调用

x = foo.pixsolve(A.T, b.T).T

(由于f2py中的设计选择不佳,需要.T,如果省略.T,这两者都会导致不必要的复制,低效的内存访问模式和不自然的查找fortran索引。)

这也避免了setup.py等。我没有用fortran选择的骨头(只要字符串不涉及),但我希望numpy可能有一些短而优雅的东西可以做同样的事情的事情。

答案 1 :(得分:3)

对于那些现在回来阅读这个问题的人,我想我会节省其他时间,并提到numpy现在使用广播处理这个问题。

因此,在numpy 1.8.0及更高版本中,以下内容可用于求解N个线性方程。

x = np.linalg.solve(A,b)

答案 2 :(得分:2)

我认为你明确循环是一个问题是错的。通常它只是最内层的循环,值得优化,我认为这也适用于此。例如,我们可以测量开销的代码与实际计算的代价:

import numpy as np

n = 10**6
A = np.random.random(size=(n, 3, 3))
b = np.random.random(size=(n, 3))
x = b*0

def f():
    for i in xrange(n):
        x[i,:] = np.linalg.solve(A[i,:,:],b[i,:])

np.linalg.pseudosolve = lambda a,b: b

def g():
    for i in xrange(n):
        x[i,:] = np.linalg.pseudosolve(A[i,:,:],b[i,:])

给了我

In [66]: time f()
CPU times: user 54.83 s, sys: 0.12 s, total: 54.94 s
Wall time: 55.62 s

In [67]: time g()
CPU times: user 5.37 s, sys: 0.01 s, total: 5.38 s
Wall time: 5.40 s

IOW,除了实际解决问题之外,它只花费10%的时间做其他事情。现在,我完全相信np.linalg.solve本身对于你从Fortran中得到的东西来说太慢了,所以你想要做别的事情。在小问题上尤其如此,想一想:IIRC我曾经发现手动展开某些小解决方案的速度更快,尽管那是一段时间了。

但就其本身而言,在第一个索引上使用显式循环并不会使整个解决方案变得非常缓慢。如果np.linalg.solve足够快,那么循环在这里不会改变太多。

答案 3 :(得分:0)

我认为你可以一次性完成,对角线周围有3x3块组成的(3x100000,3x100000)矩阵。

未经测试:

b_new = np.vstack([ b[i,:] for i in range(len(i)) ])
x_new = np.zeros(shape=(3x10000,3) )

A_new = np.zeros(shape=(3x10000,3x10000) )
n,m = A.shape
for i in range(n):
   A_new[3*i:3*(i+1),3*i:3*(i+1)] = A[i,:,:]

x = np.linalg.solve(A_new,b_new)