我正在尝试扩展numpy的“ tensordot”,例如:
K_ijklm = A_ki * B_jml
可以这样清晰地写:K = mytensordot(A,B,[2,0],[1,4,3])
据我了解,numpy的张量点(带有可选参数0)将能够执行以下操作:K_kijml = A_ki * B_jml
,即保持索引的顺序。因此,我将不得不做许多np.swapaxes()
来获得矩阵“ K_ijklm”,在复杂的情况下,矩阵很容易成为错误的来源(可能很难调试)。
问题是,即使使用numba,我的实现也很慢(比tensordot慢10倍[编辑:实际上比tensordot慢得多])。我想知道是否有人会为提高算法性能做些什么。
import numpy as np
import numba as nb
import itertools
import timeit
@nb.jit()
def myproduct(dimN):
N=np.prod(dimN)
L=len(dimN)
Product=np.zeros((N,L),dtype=np.int32)
rn=0
for n in range(1,N):
for l in range(L):
if l==0:
rn=1
v=Product[n-1,L-1-l]+rn
rn = 0
if v == dimN[L-1-l]:
v = 0
rn = 1
Product[n,L-1-l]=v
return Product
@nb.jit()
def mytensordot(A,B,iA,iB):
iA,iB = np.array(iA,dtype=np.int32),np.array(iB,dtype=np.int32)
dimA,dimB = A.shape,B.shape
NdimA,NdimB=len(dimA),len(dimB)
if len(iA) != NdimA: raise ValueError("iA must be same size as dim A")
if len(iB) != NdimB: raise ValueError("iB must be same size as dim B")
NdimN = NdimA + NdimB
dimN=np.zeros(NdimN,dtype=np.int32)
dimN[iA]=dimA
dimN[iB]=dimB
Out=np.zeros(dimN)
indexes = myproduct(dimN)
for nidxs in indexes:
idxA = tuple(nidxs[iA])
idxB = tuple(nidxs[iB])
v=A[(idxA)]*B[(idxB)]
Out[tuple(nidxs)]=v
return Out
A=np.random.random((4,5,3))
B=np.random.random((6,4))
def runmytdot():
return mytensordot(A,B,[0,2,3],[1,4])
def runtensdot():
return np.tensordot(A,B,0).swapaxes(1,3).swapaxes(2,3)
print(np.all(runmytdot()==runtensdot()))
print(timeit.timeit(runmytdot,number=100))
print(timeit.timeit(runtensdot,number=100))
True
1.4962144780438393
0.003484356915578246
答案 0 :(得分:1)
<?php
function returnString() {
$name = $_POST['postname'];
echo "the name entered ->", $name, " <- hier";
return $name;
}
returnString();
可能难以理解。我在
How does numpy.tensordot function works step-by-step?
我推断出tensordot
与np.tensordot(A, B, axes=0)
是等效的。
axes=[[], []]
这反过来相当于以新的大小为1的产品和维度调用In [757]: A=np.random.random((4,5,3))
...: B=np.random.random((6,4))
In [758]: np.tensordot(A,B,0).shape
Out[758]: (4, 5, 3, 6, 4)
In [759]: np.tensordot(A,B,[[],[]]).shape
Out[759]: (4, 5, 3, 6, 4)
:
dot
In [762]: np.dot(A[...,None],B[...,None,:]).shape
Out[762]: (4, 5, 3, 6, 4)
(4,5,3,1) * (6,1,4) # the 1 is the last of A and 2nd to the last of B
使用BLAS(或等效代码)的速度很快。交换轴和重塑也相对较快。
dot
使我们对轴有很多控制
复制以上产品:
einsum
并进行交换:
In [768]: np.einsum('jml,ki->jmlki',A,B).shape
Out[768]: (4, 5, 3, 6, 4)
次要点-双重交换可以写为一个转置:
In [769]: np.einsum('jml,ki->ijklm',A,B).shape
Out[769]: (4, 4, 6, 3, 5)
答案 1 :(得分:0)
您遇到了a known issue。创建多维数组时,numpy.zeros
requires a tuple。如果您传递的不是元组,则有时可以使用,但这仅是因为numpy
可以将对象首先转换为元组。
问题在于numba
当前不支持conversion of arbitrary iterables into tuples。因此,当您尝试在nopython=True
模式下编译时,此行失败。 (其他几个也失败了,但这是第一个。)
Out=np.zeros(dimN)
从理论上讲,您可以调用np.prod(dimN)
,创建一个零的平面数组,然后对其进行整形,但是随后遇到了一个完全相同的问题:reshape
数组的numpy
方法需要一个元组!
numba
这是一个非常烦人的问题-我以前从未遇到过。我真的怀疑我找到的解决方案是正确的,但这是一个可行的解决方案,它允许我们以nopython=True
模式编译版本。
核心思想是通过直接实现紧跟数组strides
的索引器来避免使用元组进行索引:
@nb.jit(nopython=True)
def index_arr(a, ix_arr):
strides = np.array(a.strides) / a.itemsize
ix = int((ix_arr * strides).sum())
return a.ravel()[ix]
@nb.jit(nopython=True)
def index_set_arr(a, ix_arr, val):
strides = np.array(a.strides) / a.itemsize
ix = int((ix_arr * strides).sum())
a.ravel()[ix] = val
这使我们无需元组即可获取和设置值。
我们还可以避免使用reshape
,方法是将输出缓冲区传递到jitted函数中,并将该函数包装在辅助函数中:
@nb.jit() # We can't use nopython mode here...
def mytensordot(A, B, iA, iB):
iA, iB = np.array(iA, dtype=np.int32), np.array(iB, dtype=np.int32)
dimA, dimB = A.shape, B.shape
NdimA, NdimB = len(dimA), len(dimB)
if len(iA) != NdimA:
raise ValueError("iA must be same size as dim A")
if len(iB) != NdimB:
raise ValueError("iB must be same size as dim B")
NdimN = NdimA + NdimB
dimN = np.zeros(NdimN, dtype=np.int32)
dimN[iA] = dimA
dimN[iB] = dimB
Out = np.zeros(dimN)
return mytensordot_jit(A, B, iA, iB, dimN, Out)
由于辅助程序不包含循环,因此增加了一些开销,但开销却微不足道。这是最后的固定功能:
@nb.jit(nopython=True)
def mytensordot_jit(A, B, iA, iB, dimN, Out):
for i in range(np.prod(dimN)):
nidxs = int_to_idx(i, dimN)
a = index_arr(A, nidxs[iA])
b = index_arr(B, nidxs[iB])
index_set_arr(Out, nidxs, a * b)
return Out
不幸的是,这并没有像我们希望的那样产生尽可能多的加速。在数组上,它比tensordot
慢5倍;在较大的阵列上,速度仍然慢50倍。 (但至少不慢1000倍!)回想起来,这并不奇怪,因为dot
和tensordot
都在后台使用BLAS,就像@hpaulj reminds us一样。
完成此代码后,我看到einsum
解决了您的实际问题-太好了!
但是您最初的问题所指向的潜在问题-在jitted代码中不可能使用任意长度的元组进行索引-仍然令人沮丧。因此希望这对其他人有用!