我尝试使用np.meshgrid
创建新矩阵,将seq.shape = (a,b,c)
之类的矩阵转换为pairwise_matrix.shape = (a,b,b,c,3)
,其中新成对矩阵中的最后一个维度是{{1}的串联来自vi
的},(vi+vj)/2
和vj
(vi
,vj
。但似乎seq
不适用于高维矩阵。还有其他运营商可以做到这一点吗?
答案 0 :(得分:0)
您需要像这样广播您的输入:
>>> import numpy as np
>>>
>>> a, b, c = 2, 3, 4
>>> seq = np.arange(a*b*c).reshape(a, b, c)
>>>
>>> weight = np.linspace(0, 1, 3)
>>> result = seq[..., None, :, None] * weight + seq[:, None, ..., None] * weight[::-1]
>>> result.shape
(2, 3, 3, 4, 3)
>>> result
array([[[[[ 0., 0., 0.],
[ 1., 1., 1.],
[ 2., 2., 2.],
[ 3., 3., 3.]],
[[ 4., 2., 0.],
[ 5., 3., 1.],
[ 6., 4., 2.],
[ 7., 5., 3.]],
[[ 8., 4., 0.],
[ 9., 5., 1.],
[10., 6., 2.],
[11., 7., 3.]]],
...
[[[12., 16., 20.],
[13., 17., 21.],
[14., 18., 22.],
[15., 19., 23.]],
[[16., 18., 20.],
[17., 19., 21.],
[18., 20., 22.],
[19., 21., 23.]],
[[20., 20., 20.],
[21., 21., 21.],
[22., 22., 22.],
[23., 23., 23.]]]]])