用点简化表达式,得到正确的维数

时间:2017-05-23 14:11:16

标签: python sympy

我想创建一个包含常量的dot乘积的同情表达式,该常量在numpy.dot上转换为lambdify,我可以simplify {。}}。

不幸的是,使用sympy.Matrix总是会添加一个额外的维度,以使dot()的形状最终总是有1

import numpy
import sympy


class dot(sympy.Function):
    pass


a = sympy.Matrix([1, 2, 3])
beta = sympy.Symbol('beta')
x = sympy.Symbol('x')

expr = beta * dot(x, a)
expr = sympy.simplify(expr)
f = sympy.lambdify((x, beta), expr)

x = numpy.random.rand(7, 77, 3)
beta = numpy.random.rand(7, 77)
f(x, beta)

这给出了错误

ValueError: operands could not be broadcast together with shapes (7,77) (7,77,1) 

用等效的sympy.Matrix替换numpy.array会使dot正常工作,但在simplify(使用Python 3)失败。 (Bug filed.

我已经没有解决如何解决这个问题的想法了。任何提示?

1 个答案:

答案 0 :(得分:0)

由于this bug已在master中修复(请参阅PR),现在可以使用以下内容

import numpy
import sympy


class dot(sympy.Function):
    pass


a = numpy.array([1, 2, 3])
beta = sympy.Symbol('beta')
x = sympy.Symbol('x')

expr = beta * dot(x, a)
expr = sympy.simplify(expr)
f = sympy.lambdify((x, beta), expr)

x = numpy.random.rand(7, 77, 3)
beta = numpy.random.rand(7, 77)
f(x, beta)