我有以下张量执行,
np.einsum('k,pjqk,yzjqk,yzk,ipqt->it', A, B, C, D, E)
我注意到当'z'或'q'在维度上扩展时,执行时间确实受到影响,虽然我的直觉是它可能不应该那么糟糕 - 也许这是我的输入形式,我可以通过手动优化张量收缩。
经过一番挖掘,我发现优化有两种模式:'最佳'和'贪婪'。如果我分别针对两种模式评估了我的路径:
(['einsum_path', (0, 3), (1, 3), (0, 2), (0, 1)],
' Complete contraction: k,pjqk,yzjqk,yzk,ipqt->it\n'
' Naive scaling: 8\n'
' Optimized scaling: 5\n'
' Naive FLOP count: 5.530e+04\n'
' Optimized FLOP count: 7.930e+02\n'
' Theoretical speedup: 69.730\n'
' Largest intermediate: 2.400e+01 elements\n'
'--------------------------------------------------------------------------\n'
'scaling current remaining\n'
'--------------------------------------------------------------------------\n'
' 3 yzk,k->yzk pjqk,yzjqk,ipqt,yzk->it\n'
' 5 yzk,yzjqk->jqk pjqk,ipqt,jqk->it\n'
' 4 jqk,pjqk->qp ipqt,qp->it\n'
' 4 qp,ipqt->it it->it')
和
(['einsum_path', (2, 3), (1, 3), (1, 2), (0, 1)],
' Complete contraction: k,pjqk,yzjqk,yzk,ipqt->it\n'
' Naive scaling: 8\n'
' Optimized scaling: 5\n'
' Naive FLOP count: 5.530e+04\n'
' Optimized FLOP count: 1.729e+03\n'
' Theoretical speedup: 31.981\n'
' Largest intermediate: 4.800e+01 elements\n'
'--------------------------------------------------------------------------\n'
'scaling current remaining\n'
'--------------------------------------------------------------------------\n'
' 5 yzk,yzjqk->jqk k,pjqk,ipqt,jqk->it\n'
' 4 jqk,pjqk->qkp k,ipqt,qkp->it\n'
' 5 qkp,ipqt->tik k,tik->it\n'
' 3 tik,k->it it->it')
测试结果表明,“最佳”对我来说要快得多,如图所示。
任何人都可以用简单的术语解释差异是什么以及为什么'贪婪'被设置为默认值?
始终使用'最佳'有什么缺点?
如果我的einsum计算将运行1000次(它是优化迭代的一部分),我应该重新构建执行以自动从“最佳”路径中受益,而不必重新计算它(或者每次'贪婪的'路径?)
答案 0 :(得分:1)
对于发现这一点的人来说,阅读更多的内容显示以下内容:
“贪婪”通常非常有效,在大多数情况下会产生“最佳”解决方案,并且执行起来更快。对于可能在迭代循环中无意中使用einsum的普通用户来说,将“贪婪”作为默认值就足够了。否则,对于一次性计算,似乎“最优”的最小额外开销意味着它可以有效地使用,除非可能用于大量索引,并且它可能提供很大的提升(如我的情况)
在循环中,最好的办法是预先计算它(或在第一次迭代中计算并更新非局部变量)并将其作为参数提供:
path, display = np.einsum_path('k,pjqk,yzjqk,yzk,ipqt->it', A, B, C, D, E, optimize='optimal')
for i in range(BIG_INT):
# other things
calculation = np.einsum_path('k,pjqk,yzjqk,yzk,ipqt->it', A, B, C, D, E, optimize=path)
# more things