使用Python numpy einsum获取2个矩阵之间的点积

时间:2017-08-26 15:22:45

标签: python numpy matrix dot-product

刚刚遇到这个:

Vectorized way of calculating row-wise dot product two matrices with Scipy

这个numpy.einsum非常棒,但使用时有点混乱。假设我有:

import numpy as np
a = np.array([[1,2,3], [3,4,5]])
b = np.array([[0,1,2], [1,1,7]])

我如何使用" ij"在einsum获得一个"十字点产品" a和b之间?

基本上使用这个例子我想计算

的点积

[1,2,3]和[0,1,2]

[1,2,3]和[1,2,7]

[3,4,5]和[0,1,2]

[3,4,5]和[1,1,7]

最后以[[8,26],[14,42]]

结束

我知道我是否使用

np.einsum("ij,ij->i",a,b)

我最终会以[8,42]结束,这意味着我错过了"交叉"元素

1 个答案:

答案 0 :(得分:3)

您的结果仍然是二维的,因此您需要两个索引。你需要的是一个矩阵乘法,第二个数组是转置的,所以你用Int32.inheritance_chain # => "Int32 > Int > Number > Value > Object" String.inheritance_chain # => "String > Reference > Object" Float64.inheritance_chain # => "Float64 > Float > Number > Value > Object" Array(Bool).inheritance_chain # => "Array(Bool) > Reference > Object" Hash(Bool, Bool).inheritance_chain # => "Hash(Bool, Bool) > Reference > Object" Tuple(Char).inheritance_chain # => "Tuple(Char) > Struct > Value > Object" NamedTuple(s: String, b: Bool).inheritance_chain # => "NamedTuple(s: String, b: Bool) > Struct > Value > Object" Nil.inheritance_chain # => "Nil > Value > Object" Regex.inheritance_chain # => "Regex > Reference > Object" Symbol.inheritance_chain # => "Symbol > Value > Object" Proc(Int32).inheritance_chain # => "Proc(Int32) > Struct > Value > Object" Set(String).inheritance_chain # => "Set(String) > Struct > Value > Object" Exception.inheritance_chain # => "Exception > Reference > Object" Class.inheritance_chain # => "Class > Value > Object" # union alias UnionType = Int32 | Nil | String UnionType.inheritance_chain # => "(Int32 | String | Nil) > Value > Object" # nilable Int32?.inheritance_chain # => "(Int32 | Nil) > Value > Object" # pointer alias Int32Ptr = Int32* Int32Ptr.inheritance_chain # => "Pointer(Int32) > Struct > Value > Object" # ... 转换第二个矩阵而不是正常ij,jk->ik

ij,kj->ik

相当于:

np.einsum('ij,kj->ik', a, b)

#array([[ 8, 24],
#       [14, 42]])
np.dot(a, b.T)

#array([[ 8, 24],
#       [14, 42]])