使用Python中的MatrixSymbol简化矩阵表达式

时间:2018-09-06 00:39:07

标签: python sympy

我正在使用SymPy进行符号矩阵计算,但是某些语句非常大。似乎有一种方法可以进一步简化它们。我使用过simplify(),但未能成功获得想要的东西。

例如,下面的图像是一个矩阵,它是通过执行以前的矩阵计算的长列表而获得的。 The output of series of matrix calculations, which requires further simplification

最后一条语句有两个加法和一个矩阵乘法。我想知道是否有什么方法也可以在右边执行矩阵乘法,所以我们可以简单地得到3个矩阵求和?

我知道这可以通过手工执行某些代数运算来完成,但是我对要执行的命令更感兴趣,因此该命令将整个语句作为输入,并进行所有简化,包括任何乘法,加法和加法和输出我需要的。所有这些都应使用sympy完成。换句话说,如果可以进行加法或乘法运算,那么我希望它完成而不会被遗忘。

这是模仿我的问题的MCVE

from sympy import *
init_printing()
J_22 = MatrixSymbol('J_22', 3, 3)
COV_b=Matrix([[2,1,1],[1,2,1],[1,1,2]])
(COV_b+J_22)*COV_b

此代码的输出为

The output of the MCVE

但是,我希望将其作为输出

The desired output

我了解在这个简单的示例中,我可以通过以下代码简单地解决问题

from sympy import *
init_printing()
J_22 = MatrixSymbol('J_22', 3, 3)
COV_b=Matrix([[2,1,1],[1,2,1],[1,1,2]])
(COV_b*COV_b+J_22*COV_b)

但是,这只是一个简单的示例,在实际问题中,在生成输出之前看不到它。因此,我希望能够使用一个命令,该命令将第一个提供的代码的输出作为输入并输出所需的输出。

更新:@WelcometoStackOverflow提供了一个函数,该函数简化了很多事情,但仍使矩阵加法运算未完成。

from sympy import *
init_printing()
J_22 = MatrixSymbol('J_22', 3, 3)
COV_b=Matrix([[2,1,1],[1,2,1],[1,1,2]])
T=(COV_b+J_22)*COV_b+COV_b
def expand_matmul(expr):
    import itertools
    for a in preorder_traversal(expr):
        if isinstance(a, MatMul):
            terms = [f.args if isinstance(f, MatAdd) else [f] for f in a.args]
            expanded = Add(*[MatMul(*t) for t in itertools.product(*terms)])
            if a != expanded:
                expr = expr.xreplace({a: expanded})
                return expand_matmul(expr)
    return expr
expand_matmul(T)

输出为

enter image description here[4]

,前两个矩阵之间的和仍未执行。

1 个答案:

答案 0 :(得分:1)

这是SymPy表达式Can't expand matrix expression的一个已知老问题。矩阵表达式模块很有用,但不是SymPy中最活跃的模块。我组合了一个函数来扩展此类内容。

def expand_matmul(expr):
    import itertools
    for a in preorder_traversal(expr):
        if isinstance(a, MatMul) and any(isinstance(f, MatAdd) for f in a.args):
            terms = [f.args if isinstance(f, MatAdd) else [f] for f in a.args]
            expanded = MatAdd(*[MatMul(*t) for t in itertools.product(*terms)])
            if a != expanded:
                expr = expr.xreplace({a: expanded})
                return expand_matmul(expr)
    return expr

该函数从最高级别遍历表达式树,以寻找扩展MatMul的机会。返回的表达式可以受益于doit方法调用,可以从显式矩阵执行任何撤消的乘法,如下例所示。

J_22 = MatrixSymbol('J_22', 3, 3)
COV_b=Matrix([[2,1,1],[1,2,1],[1,1,2]])
T=(COV_b+J_22)*COV_b+COV_b  
pprint(expand_matmul(T).doit())

打印

⎡8  6  6⎤       ⎡2  1  1⎤
⎢       ⎥       ⎢       ⎥
⎢6  8  6⎥ + J₂₂⋅⎢1  2  1⎥
⎢       ⎥       ⎢       ⎥
⎣6  6  8⎦       ⎣1  1  2⎦