如何让numpy.einsum与sympy一起玩?

时间:2013-03-25 03:01:30

标签: python numpy nested sympy

好的,所以我有几个多维的numpy数组的sympy对象(表达式)。例如:

A = array([[1.0*cos(z0)**2 + 1.0, 1.0*cos(z0)],
          [1.0*cos(z0), 1.00000000000000]], dtype=object)

等等。

我想要做的是使用einsum将这些数组中的几个相乘,因为我已经通过我之前做过的数值计算得到了它的语法。问题是,当我尝试做类似

的事情时
einsum('ik,jkim,j', A, B, C)

我收到类型错误:

TypeError: invalid data type for einsum

当然,所以在Google上进行快速搜索会让我看到einsum可能无法做到这一点,但没有理由说明原因。特别是,在这些数组上调用numpy.dot()和numpy.tensordot()函数就像一个魅力。我可以使用tensordot来做我需要的事情,但当我考虑必须用嵌套替换50个左右的Einsten总结(其中indeces的顺序非常重要)时,我的大脑会受到伤害tenordot电话。更为噩梦的是,必须调试该代码并寻找那个错位的索引交换。

长话短说,有没有人知道为什么讲故事与物体一起工作但是einsum不会?有关解决方法的任何建议吗?如果没有,关于如何编写我自己的包装器到嵌套的tensordot调用的任何建议有点类似于einsum表示法(数字而不是字母都可以)?

3 个答案:

答案 0 :(得分:4)

Einsum基本上取代了tensordot(不是dot,因为dot通常使用优化的线性代数包),代码方面它完全不同。

这是一个对象einsum,它未经测试(对于更复杂的东西),但我认为它应该有效...在C中做同样的事情可能更简单,因为你可以从真正的einsum窃取除循环本身之外的所有东西功能。因此,如果您愿意,请实施并让更多人满意......

https://gist.github.com/seberg/5236560

我不保证任何东西,特别是对于怪异角落的情况。当然你也可以将einsum符号转换为tensordot表示法,我确信,这可能会快一点,因为循环最终会主要在C ...

答案 1 :(得分:3)

有趣的是,添加 optimize="optimal" 对我有用

einsum('ik,jkim,j', A, B, C) 产生错误,但是

einsum('ik,jkim,j', A, B, C, optimize="optimal") 与 sympy 完美配合。

答案 2 :(得分:2)

这是一个更简单的实现,它将多个einsum中的tensordot分开。

def einsum(string, *args):
    index_groups = map(list, string.split(','))
    assert len(index_groups) == len(args)
    tensor_indices_tuples = zip(index_groups, args)
    return reduce(einsum_for_two, tensor_indices_tuples)[1]

def einsum_for_two(tensor_indices1, tensor_indices2):
    string1, tensor1 = tensor_indices1
    string2, tensor2 = tensor_indices2
    sum_over_indices = set(string1).intersection(set(string2))
    new_string = string1 + string2
    axes = ([], [])
    for i in sum_over_indices:
        new_string.remove(i)
        new_string.remove(i)
        axes[0].append(string1.index(i))
        axes[1].append(string2.index(i))
    return new_string, np.tensordot(tensor1, tensor2, axes)

首先,它分隔(indices,tensor)元组中的einsum参数。然后它按如下方式减少列表:

  • 获取前两个元组,并在它们上评估一个简单的einsum_for_two。它还打印出新的索引签名。
  • einsum_for_two的值与列表中的下一个元组一起用作einsum_for_two的新参数。
  • 继续直到只留下元组。索引签名被丢弃,只返回张量。

可能很慢(但无论如何你使用object dtype)。它没有对输入进行很多正确性检查。

正如@seberg所说,我的代码不适用于张量的痕迹。