Speed difference in np.einsum

Time: 2017-08-16 14:37:25

Tags: python numpy numpy-einsum

I noticed that np.einsum is faster when it reduces one dimension:

import numpy as np
a = np.random.random((100,100,100))
b = np.random.random((100,100,100))

%timeit np.einsum('ijk,ijk->ijk',a,b)
# 100 loops, best of 3: 3.83 ms per loop
%timeit np.einsum('ijk,ijk->ij',a,b)
# 1000 loops, best of 3: 937 µs per loop
%timeit np.einsum('ijk,ijk->i',a,b)
# 1000 loops, best of 3: 921 µs per loop
%timeit np.einsum('ijk,ijk->',a,b)
# 1000 loops, best of 3: 928 µs per loop

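This seems odd to me, because I expected it to build the new array first and then sum over it, which evidently is not what happens. What is going on there? Why does it get faster when one dimension is dropped, but not any faster when further dimensions are dropped?

Side note: I first thought it had to do with creating a large array when the result has many dimensions, but I don't think that is the case: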

 %timeit np.ones(a.shape)
 # 1000 loops, best of 3: 1.79 ms per loop
 %timeit np.empty(a.shape)
 # 100000 loops, best of 3: 3.05 µs per loop

since creating a new array is much faster than that.

1 Answer:

Answer 0: (score: 1)

einsum is implemented in compiled code, in numpy/numpy/core/src/multiarray/einsum.c.src.

The core operation is to iterate over all dimensions (e.g. 100*100*100 times in your case) using the C version of nditer, applying the sum-of-products calculation defined by the ijk string.
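To make the sum-of-products idea concrete, here is a minimal pure-Python sketch for the 'ijk,ijk->ij' case. It is only an illustration of the arithmetic, not the actual C implementation; the helper name sum_of_products_ij and the small array shapes are assumptions chosen for this example.

```python
import numpy as np

def sum_of_products_ij(a, b):
    """Naive reference for np.einsum('ijk,ijk->ij', a, b) (illustrative only)."""
    ni, nj, nk = a.shape
    out = np.zeros((ni, nj))
    for i in range(ni):
        for j in range(nj):
            for k in range(nk):
                # Every (i, j, k) combination is visited exactly once;
                # k is absent from the output, so it is summed away.
                out[i, j] += a[i, j, k] * b[i, j, k]
    return out

# Small arrays keep the Python loops fast enough to run interactively.
rng = np.random.default_rng(0)
a_small = rng.random((4, 5, 6))
b_small = rng.random((4, 5, 6))

reference = sum_of_products_ij(a_small, b_small)
assert np.allclose(reference, np.einsum('ijk,ijk->ij', a_small, b_small))
```

The loop count is the same whichever output subscripts are requested; only the place each product is accumulated changes.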

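But it performs various optimizations, including generating views when no multiplication is required. So it would take careful study to see what is different in your cases.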
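The view optimization can be observed from Python with np.shares_memory. The sketch below uses a small example matrix m (an assumption for illustration) and compares two single-operand calls that need no arithmetic with a two-operand call that does.

```python
import numpy as np

m = np.arange(9.0).reshape(3, 3)

transposed = np.einsum('ij->ji', m)      # only relabels the axes
diagonal = np.einsum('ii->i', m)         # picks the diagonal, no summation
product = np.einsum('ij,ij->ij', m, m)   # needs an actual multiplication

# A view shares its buffer with the input; a computed result does not.
print(np.shares_memory(transposed, m))
print(np.shares_memory(diagonal, m))
print(np.shares_memory(product, m))
```

The calls in the question all multiply two operands, so they should fall in the last category and produce a freshly computed array.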

The time gap is between producing a 3D output with no summation and producing the outputs that sum over one or more axes.
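One way to probe that gap outside IPython is a small script built on the standard timeit module. The helper best_time and the extra out= case are additions for this sketch, not part of the original measurements. The out= case writes the full 3D result into an existing array, which may help separate the cost of allocating the output from the cost of filling it. Absolute numbers will vary by machine and NumPy version.

```python
import timeit
import numpy as np

a = np.random.random((100, 100, 100))
b = np.random.random((100, 100, 100))

def best_time(func, number=10, repeat=3):
    """Best per-call time in seconds over several repeats."""
    return min(timeit.repeat(func, number=number, repeat=repeat)) / number

timings = {}
for subscripts in ('ijk,ijk->ijk', 'ijk,ijk->ij', 'ijk,ijk->i', 'ijk,ijk->'):
    # Bind the current subscripts string as a default argument.
    timings[subscripts] = best_time(lambda s=subscripts: np.einsum(s, a, b))

# Same full-size case, but writing into an already existing array.
out = np.empty_like(a)
timings['ijk,ijk->ijk (out=)'] = best_time(
    lambda: np.einsum('ijk,ijk->ijk', a, b, out=out))

for label, seconds in timings.items():
    print(f'{label:22s} {seconds * 1e6:10.1f} µs')
```

Note that timing np.empty on its own, as in the question, only measures reserving the memory, not writing a million values into it.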