如何使用numpy einsum_path结果?

时间:2019-02-01 14:14:03

标签: python numpy numpy-einsum

我正在使用numpy einsum在某些3维和4维张量上执行相当复杂的操作。

我的实际代码是

np.einsum('oij,imj,mjkn,lnk,plk->op',phi,B,Suu,B,phi)

这就是我想要的。

使用einsum_path,结果为:

>>> path = np.einsum_path('oij,imj,mjkn,lnk,plk->op',phi,B,Suu,B,phi)

>>> print(path[0])
['einsum_path', (0, 1), (0, 3), (0, 1), (0, 1)]

>>> print(path[1])
  Complete contraction:  oij,imj,mjkn,lnk,plk->op
         Naive scaling:  8
     Optimized scaling:  5
      Naive FLOP count:  2.668e+07
  Optimized FLOP count:  1.340e+05
   Theoretical speedup:  199.136
  Largest intermediate:  7.700e+02 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   4                imj,oij->moj                     mjkn,lnk,plk,moj->op
   5               moj,mjkn->nok                          lnk,plk,nok->op
   4                plk,lnk->npk                              nok,npk->op
   4                 npk,nok->op                                   op->op

这表明理论上的加速约为200倍。

如何使用此结果来加快我的代码的速度?如何“实现” einsum_path告诉我的内容?

2 个答案:

答案 0 :(得分:5)

做一些时间测试

_TypeError (type '_InternalLinkedHashMap<dynamic, dynamic>'

在我的测试中,第二个2以相同的速度运行。对于一个小问题,which is pointing to this更快,大概是因为分析和重新安排需要时间。对于较大的问题,如果理论加速较大,则path = np.einsum_path('oij,imj,mjkn,lnk,plk->op',phi,B,Suu,B,phi) np.einsum('oij,imj,mjkn,lnk,plk->op',phi,B,Suu,B,phi, optimize=False) np.einsum('oij,imj,mjkn,lnk,plk->op',phi,B,Suu,B,phi, optimize=True) np.einsum('oij,imj,mjkn,lnk,plk->op',phi,B,Suu,B,phi, optimize=path[0]) 的实际加速速度可能大于理论速度。据推测,内存管理正在减慢optimize=False的情况。

True就是这样,仅基于FLOPS计数进行估算。仅在FLOPS主导计算的情况下才是正确的。

您还可以计时False计算。问题的大小将决定问题的时间是总时间的小还是大。

答案 1 :(得分:0)

From the source code

Theoretical Speedup = speedup = naive_cost / opt_cost
naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)

因此,从这个角度来看,要加快该过程,您需要降低FLOPS (Floating Point Operations Per Second)。由于天真成本是未优化表达式的成本,因此您需要以这样的方式重写表达式:删除与表达式相关联的所有“垃圾”,而无需更改表达式的基础结构。

从判断您正在做一些复杂表达式的问题出发,这可能是不可能的。但是要回答您的问题,请尝试重写表达式,这是一种更紧凑的方法,以便降低理论上的速度。

您可以尝试使用其他路径来降低FLOPS。