将numpy中定义的函数转换为sympy

时间:2019-03-22 20:25:25

标签: python numpy sympy

我有一个在numpy中定义的函数,希望将其转换为sympy,因此可以将其应用于符号sympy变量。尝试将numpy函数直接应用于sympy变量失败:

import numpy as np
import sympy as sp

def np_fun(a):
    return np.array([np.sin(a), np.cos(a)])

x = sp.symbols('x')
sp_fun = np_fun(x)

我得到了错误

AttributeError: 'Symbol' object has no attribute 'sin'

我的下一个想法是将numpy函数转换为sympy,但是我找不到实现此目的的方法。我知道我可以通过将函数定义为sympy表达式来使此代码起作用:

sp_fun = sp.Array([sp.sin(x), sp.cos(x)])

但是我使用正弦/余弦函数作为简单示例。我正在使用的实际函数已经在numpy中定义,并且更加复杂,因此重写它会非常繁琐。

2 个答案:

答案 0 :(得分:0)

我建议使用“查找和替换”将numpy函数修改为sympy表达式。您可以在Python中使用str.replace()并定义规则以替换适合您的功能的文本。如果您发布函数,则提供更多细节会更加容易。

答案 1 :(得分:0)

原则上,您可以直接修改函数的ast(“抽象语法树”),尽管在实践中它可能会变得很繁琐。无论如何,这是您的简单示例的处理方法:

这将从源创建ast,并从void omp_set_schedule(omp_sched_t kind, int chunk_size); typedef enum omp_sched_t { omp_sched_static = 1, omp_sched_dynamic = 2, omp_sched_guided = 3, omp_sched_auto = 4 } omp_sched_t; 类派生以修改ast。节点转换器具有通用的访问方法,该方法遍历节点及其子树,委派给派生类中特定于节点的访问者。在这里,我们将所有名称/* wherever currently isDynamic is set */ if (isDynamic) { omp_set_schedule(omp_sched_dynamic, 10); } else { /* chunk_size < 1 uses default */ omp_set_schedule(static, 0); } /* later */ #pragma omp parallel for num_threads(thread_count) default(shared) private(...) schedule(runtime) for (...) { /* do a thing */ } 更改为NodeTransformer,然后将这些属性更改为拼写不同的前np现在sp。您必须将所有这些差异添加到np字典中。

最后,我们从ast编译回代码对象,并执行它以使修改后的功能可用。

sp

输出:

translate

更新简单增强功能:修改由函数调用的函数:

import ast, inspect
import numpy as np
import sympy as sp

def f(a):
    return np.array([np.sin(a), np.cos(a)])

z = ast.parse(inspect.getsource(f))

translate = {'array': 'Array'}

class np_to_sp(ast.NodeTransformer):
    def visit_Name(self, node):
        if node.id=='np':
            node = ast.copy_location(ast.Name(id='sp', ctx=node.ctx), node)
        return node
    def visit_Attribute(self, node):
        self.generic_visit(node)
        if node.value.id=='sp' and node.attr in translate:
            fields = {k: getattr(node, k) for k in node._fields}
            fields['attr'] = translate[node.attr]
            node = ast.copy_location(ast.Attribute(**fields), node)
        return node

np_to_sp().visit(z)

exec(compile(z, '', 'exec'))

x = sp.Symbol('x')
print(f(x))

打印:

[sin(x), cos(x)]