I am using Python.
Let A.shape = (n,n,M) and B.shape = (n,n).
I want to do the following:
AB = np.zeros_like(A)
for m in range(M):
    AB[:,:,m] = A[:,:,m] @ B
However, this loop does not seem like the most efficient way to do this. Is there a better approach?
Answer 0 (score: 3)
One option is to use np.einsum:
np.einsum('ijk,jl->ilk', A, B)
Or transpose A twice:
(A.transpose(2,0,1) @ B).transpose(1,2,0)
Example:
>>> import numpy as np
>>> A = np.arange(12).reshape(2,2,3)
>>> B = np.arange(4).reshape(2,2)
>>> AB = np.zeros_like(A)
>>> M = 3
>>> for m in range(M):
...     AB[:,:,m] = A[:,:,m] @ B
...
>>> AB
array([[[ 6,  8, 10],
        [ 9, 13, 17]],

       [[18, 20, 22],
        [33, 37, 41]]])
# einsum
>>> np.einsum('ijk,jl->ilk', A, B)
array([[[ 6,  8, 10],
        [ 9, 13, 17]],

       [[18, 20, 22],
        [33, 37, 41]]])
# transpose
>>> (A.transpose(2,0,1) @ B).transpose(1,2,0)
array([[[ 6,  8, 10],
        [ 9, 13, 17]],

       [[18, 20, 22],
        [33, 37, 41]]])
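The transpose version works because `@` (np.matmul) treats an array with more than two dimensions as a stack of matrices stored in the last two axes and broadcasts over the leading axis. The following is a minimal sketch of that step-by-step, checked against the loop from the question. The sizes and the names `stacked`, `product`, `loop_result`, `matmul_result` and `einsum_result` are illustrative assumptions, not part of the original post.

```python
import numpy as np

rng = np.random.default_rng(0)
n, M = 4, 3  # illustrative sizes, chosen only for this sketch
A = rng.random((n, n, M))
B = rng.random((n, n))

# Reference: the explicit loop from the question.
loop_result = np.empty_like(A)
for m in range(M):
    loop_result[:, :, m] = A[:, :, m] @ B

# Move the stacking axis to the front: (n, n, M) -> (M, n, n).
stacked = A.transpose(2, 0, 1)
# matmul multiplies each of the M matrices by B, giving shape (M, n, n).
product = stacked @ B
# Move the stacking axis back to the end: (M, n, n) -> (n, n, M).
matmul_result = product.transpose(1, 2, 0)

# The einsum form contracts index j and keeps k as the last axis.
einsum_result = np.einsum('ijk,jl->ilk', A, B)

print(np.allclose(loop_result, matmul_result))
print(np.allclose(loop_result, einsum_result))
```

Both checks should print True. np.allclose is used instead of exact equality because the inputs here are floats and the methods may sum in a different order.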
Answer 1 (score: 3)
We can use np.tensordot:
np.tensordot(A,B,axes=(1,0)).swapaxes(1,2)
Related post to understand tensordot.
Under the hood, it reshapes, aligns the axes by permuting them, and then uses BLAS-based matrix multiplication with np.dot. That dirty work would look something along these lines:
A.swapaxes(1,2).reshape(-1,n).dot(B).reshape(n,-1,n).swapaxes(1,2)
Starting off with B, it would be something like this:
B.T.dot(A.swapaxes(0,1).reshape(n,-1)).reshape(n,n,-1).swapaxes(0,1)
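To make the equivalence of these three expressions easier to trust, here is a small sketch that wraps each one in a function and compares it with the loop from the question. The function names (`via_tensordot`, `via_reshape_A`, `via_reshape_B`, `via_loop`) and the test sizes are hypothetical, introduced only for this illustration.

```python
import numpy as np

def via_loop(A, B):
    # Reference implementation: one 2-D product per slice along the last axis.
    out = np.empty_like(A)
    for m in range(A.shape[2]):
        out[:, :, m] = A[:, :, m] @ B
    return out

def via_tensordot(A, B):
    # Contract axis 1 of A with axis 0 of B -> (i, k, l), then swap to (i, l, k).
    return np.tensordot(A, B, axes=(1, 0)).swapaxes(1, 2)

def via_reshape_A(A, B):
    n = B.shape[0]
    # (i, j, k) -> (i, k, j) -> 2-D (i*k, j), one big matrix product, then undo.
    return A.swapaxes(1, 2).reshape(-1, n).dot(B).reshape(n, -1, n).swapaxes(1, 2)

def via_reshape_B(A, B):
    n = B.shape[0]
    # (i, j, k) -> (j, i, k) -> 2-D (j, i*k), left-multiply by B.T, then undo.
    return B.T.dot(A.swapaxes(0, 1).reshape(n, -1)).reshape(n, n, -1).swapaxes(0, 1)

rng = np.random.default_rng(0)
n, M = 5, 7  # illustrative sizes with n != M to catch axis mix-ups
A = rng.random((n, n, M))
B = rng.random((n, n))

expected = via_loop(A, B)
for func in (via_tensordot, via_reshape_A, via_reshape_B):
    print(func.__name__, np.allclose(func(A, B), expected))
```

Each line should report True. Note that all three fast variants return the result through a final swapaxes, so the output is generally not C-contiguous; np.ascontiguousarray can be applied if later code needs a contiguous layout.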
Benchmarking
Setup -
np.random.seed(0)
n,M = 50,50
A = np.random.rand(n,n,M)
B = np.random.rand(n,n)
Timings -
# @Psidom's soln-1
In [18]: %timeit np.einsum('ijk,jl->ilk', A, B)
100 loops, best of 3: 10.2 ms per loop
# @Psidom's soln-2
In [19]: %timeit (A.transpose(2,0,1) @ B).transpose(1,2,0)
100 loops, best of 3: 10.7 ms per loop
# @Psidom's einsum soln-1 with optimize set as True
In [20]: %timeit np.einsum('ijk,jl->ilk', A, B,optimize=True)
1000 loops, best of 3: 1.17 ms per loop
In [21]: %timeit np.tensordot(A,B,axes=(1,0)).swapaxes(1,2)
1000 loops, best of 3: 1.09 ms per loop
In [22]: %timeit A.swapaxes(1,2).reshape(-1,n).dot(B).reshape(n,-1,n).swapaxes(1,2)
1000 loops, best of 3: 1.03 ms per loop
In [23]: %timeit B.T.dot(A.swapaxes(0,1).reshape(n,-1)).reshape(n,n,-1).swapaxes(0,1)
1000 loops, best of 3: 951 µs per loop
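The timings above come from an IPython session. For anyone who wants to repeat the comparison outside IPython, the following is a rough sketch using the standard-library timeit module with the same setup sizes. The `candidates` dictionary, its labels, and the repeat counts are assumptions made for this sketch, and the absolute numbers will depend on the machine, the NumPy version, and the BLAS library it is linked against.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
n, M = 50, 50
A = rng.random((n, n, M))
B = rng.random((n, n))

# Each entry is one of the approaches discussed in the answers.
candidates = {
    "einsum": lambda: np.einsum('ijk,jl->ilk', A, B),
    "einsum optimize=True": lambda: np.einsum('ijk,jl->ilk', A, B, optimize=True),
    "transpose + matmul": lambda: (A.transpose(2, 0, 1) @ B).transpose(1, 2, 0),
    "tensordot": lambda: np.tensordot(A, B, axes=(1, 0)).swapaxes(1, 2),
    "reshape A + dot": lambda: A.swapaxes(1, 2).reshape(-1, n).dot(B)
                                .reshape(n, -1, n).swapaxes(1, 2),
    "reshape B + dot": lambda: B.T.dot(A.swapaxes(0, 1).reshape(n, -1))
                                .reshape(n, n, -1).swapaxes(0, 1),
}

reference = candidates["einsum"]()
timings = {}
calls = 20  # calls per timing run, kept small so the script finishes quickly
for name, func in candidates.items():
    # Make sure every approach computes the same array before timing it.
    assert np.allclose(func(), reference)
    best = min(timeit.repeat(func, number=calls, repeat=3)) / calls
    timings[name] = best
    print(f"{name:22s} {best * 1e3:8.3f} ms per call")
```

Taking the minimum over the repeats mirrors the "best of 3" convention in the IPython output.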