我想知道为什么这个矩阵乘法在我的Scala程序中不起作用,而不是我在使用Python时得到的结果。我正在使用此数学描述的矩阵乘法算法:Matrix Multiplication其中我有两个矩阵a = n x m和b = m x p。我为这个算法编写的代码是(每个矩阵是一个2d的双精度数组):
def dot(other: Matrix2D): Matrix2D ={
if (this.shape(1) != other.shape(0)){
throw new IndexOutOfBoundsException("Matrices were not the right shape! [" + this.shape(1) + " != " + other.shape(0) + "]")
}
val n = this.shape(1) //returns the number of columns, shape(0) returns number of rows
var a = matrix.clone()
var b = other.matrix.clone()
var c = Array.ofDim[Double](this.shape(0), other.shape(1))
for(i <- 0 until c.length){
for (j <- 0 until c(0).length){
for (k <- 0 until n){
c(i)(j) += a(i)(k) * b(k)(j)
}
}
}
Matrix2D(c)
}
我在Scala和Python代码中输入的输入是:
a = [[1.0 1.0 1.0 1.0 0.0 0.0 0.0]
[1.0 1.0 0.0 1.0 0.0 0.0 0.0 ]
[1.0 1.0 1.0 1.0 1.0 1.0 1.0 ]
[1.0 0.0 0.0 0.0 1.0 1.0 1.0 ]
[1.0 0.0 0.0 0.0 1.0 0.0 1.0 ]
[1.0 0.0 0.0 0.0 0.0 0.0 0.0 ]]
b = [[0.0 0.0 0.0 ]
[0.0 -0.053430398509053074 0.021149859549078387 ]
[0.0 -0.010785871994186721 0.04942555653681449 ]
[0.0 0.04849323245519227 -0.0393881161667335 ]
[0.0 -0.03871752673999099 0.05228579488821056 ]
[0.0 0.07935206375269452 0.06511344235965408 ]
[0.0 -0.02462677123918247 1.723607966539059E-4 ]]
我从这个函数收到的输出是:
[[0.0 -0.015723038048047533 0.031187299919159375]
[0.0 -0.0049371660538608045 -0.018238256617655116]
[0.0 2.84727725473527E-4 0.14875889796367792 ]
[0.0 0.01600776577352106 0.11757159804451854 ]
[0.0 -0.06334429797917346 0.05245815568486446 ]
[0.0 0.0 0.0 ]]
与python的numpy.dot算法相比:
[[ 0. -0.01572304 0.0311873 ]
[ 0. -0.00493717 -0.01823826]
[ 0. -0.01572304 0.0311873 ]
[ 0. 0.08912777 0.07801112]
[ 0. 0.00977571 0.01289768]
[ 0. 0.08912777 0.07801112]]
我想知道为什么这个算法并没有完全填满我需要的输出算法......我一直在弄乱for循环等等,并且无法弄清楚什么是错的。< / p>
答案 0 :(得分:2)
你能展示你的Python代码吗?
我在Numpy中尝试了这个并获得与Scala代码相同的内容:
import numpy as np
a = np.array([[1.0,1.0,1.0,1.0,0.0,0.0,0.0],
[1.0, 1.0, 0.0, 1.0, 0.0,0.0,0.0 ],
[1.0, 1.0, 1.0, 1.0, 1.0,1.0,1.0 ],
[1.0, 0.0, 0.0, 0.0, 1.0 ,1.0,1.0 ],
[1.0, 0.0, 0.0, 0.0, 1.0, 0.0,1.0 ],
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0,0.0 ]])
b=np.array([[0.0 ,0.0 ,0.0 ],
[0.0 ,-0.053430398509053074 ,0.021149859549078387 ],
[0.0 ,-0.010785871994186721, 0.04942555653681449 ],
[0.0 , 0.04849323245519227 ,-0.0393881161667335 ],
[0.0 ,-0.03871752673999099 , 0.05228579488821056 ],
[0.0 , 0.07935206375269452 , 0.06511344235965408 ],
[0.0 ,-0.02462677123918247 ,1.723607966539059E-4 ]])
print a.dot(b)
打印:
[[ 0. -0.01572304 0.0311873 ]
[ 0. -0.00493717 -0.01823826]
[ 0. 0.00028473 0.1487589 ]
[ 0. 0.01600777 0.1175716 ]
[ 0. -0.0633443 0.05245816]
[ 0. 0. 0. ]]