我要导出一些巨大的矩阵,其中仅包含sin(q),cos(q)以及它们的和/多。 Sympy可以计算并将其导出到八度-太棒了!
但是,由于这些矩阵很大,因此我需要某种cse
甚至更好的专用优化。
我找到了this great tutorial for C code with cse。因此,我尝试自己进行了移植,但是在打印机类的某些细节上却失败了。我认为这是无限递归,导致RecursionError: maximum recursion depth exceeded
。
我的问题是:是否有一个示例将sympy-octave代码生成与优化结合在一起?或者有人可以帮助我运行所附的mwe?
import sympy as sp
t = sp.symbols('t')
from sympy.printing.octave import OctaveCodePrinter
from sympy.printing.octave import Assignment
class matlabMatrixPrinter(OctaveCodePrinter):
def _print_ImmutableDenseMatrix(self, expr):
sub_exprs, simplified = sp.cse(expr)
lines = []
for var, sub_expr in sub_exprs:
lines.append( self._print(Assignment(var, sub_expr)))
M = sp.MatrixSymbol('M', *expr.shape)
return '\n'.join(lines) + '\n' + self._print(Assignment(M, expr))
tmp = sp.sin(t)+sp.sin(t)**2
tmp = sp.ImmutableDenseMatrix((1,1,tmp))
se, ex = sp.cse(tmp)
print((ex,se))
print('\n')
#tmp = sp.Matrix([2*sp.sin(t),sp.sin(t)])
p = matlabMatrixPrinter()
print(p.doprint(tmp))
编辑:我现在发现,return语句中的第二个赋值也运行_print_ImmutableDenseMatrix函数,因此最终是递归。我不知道为什么在本教程中这对于C代码来说没有问题,但是在这里它是递归运行的。似乎只有简化表达式本身不能调用self._print函数的问题。也许有人对这些打印机有所了解,以及如何打印矩阵和这一单项作业?!
答案 0 :(得分:0)
经过大量实验,我觉得我仍然只了解codePrinter的有意工作流程背后的一些意图。但是,我编写了一个子类,该子类完全符合我的预期(请小心,因为这可能不适用于矩阵以外的任何东西!)。
也许这对某人有用!对我来说,它肯定将sympy验证为一种工作工具,因为否则,成千上万的{{1}}评估将是绝对不可行的代码。
对于仍然可以知道如何实现这些功能的某人的评论和想法,我仍然会非常感兴趣!
sin
这给出了预期的输出:
import sympy as sp
t = sp.symbols('t')
from sympy.printing.octave import OctaveCodePrinter
from sympy.printing.octave import Assignment
class matlabMatrixPrinter(OctaveCodePrinter):
def print2(self,expr_list,names=None):
sub_exprs, simplified = sp.cse(expr_list)
lines = []
for var, sub_expr in sub_exprs:
lines.append(self._print(Assignment(var, sub_expr)))
lines.append('')
for k,expr in enumerate(simplified):
if names:
M = sp.MatrixSymbol(names[k],*expr.shape)
else:
M = sp.MatrixSymbol('M{k}'.format(k=k), *expr.shape)
lines.append(self._print(Assignment(M,expr)))
result = ''
return '\n'.join(lines)
tmp = sp.Matrix([sp.sin(t)+sp.sin(t)**2 ])
tmp2 = sp.Matrix([sp.sin(t),sp.cos(t),2*sp.sin(t),sp.cos(t)**2])
p = matlabMatrixPrinter()
#print(p.print2([tmp,tmp2]))
print(p.print2([tmp,tmp2],['scalar_matrix','matrix']));
如上所述:使用时需您自担风险:)