我的一个java函数遇到了问题,它应该将2个双数组乘以矩阵。
public static double[][] matrixMultiply(double[][] m, double[][] n) {
double[][] multipliedMatrix = new double [m.length][n[0].length];
for (int i=0; i<m.length-1; i++)
{
for (int j=0; j<n[0].length-1; j++)
{
for (int k=0; k<n.length-1; k++)
{
multipliedMatrix[i][j] = multipliedMatrix[i][j] + (m[i][k] * n[k][j]);
}
}
}
return multipliedMatrix;
}
i变量应该循环遍历for循环中m(第一个矩阵)的每个元素。假设j变量循环通过第二矩阵n的每一行,并且假设变量k循环通过第一矩阵的第一行和第二矩阵的第一列中的每个元素。当输入
时,这似乎无法正常工作[[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 1.0, 2.0, 3.0]],
[[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0],
[1.0, 2.0, 3.0]]
它给出了
[[30.0, 36.0, 0.0],
[78.0, 96.0, 0.0],
[0.0, 0.0, 0.0]]
而不是
[[34.0, 44.0, 54.0],
[86.0, 112.0, 138.0],
[30.0, 45.0, 60.0]].
我不明白为什么会这样?
答案 0 :(得分:5)
修正:
public static double[][] matrixMultiply(double[][] m, double[][] n) {
double[][] multipliedMatrix = new double [m.length][n[0].length];
for (int i=0; i<m.length; i++)
{
for (int j=0; j<n[0].length; j++)
{
for (int k=0; k<n.length; k++)
{
multipliedMatrix[i][j] = multipliedMatrix[i][j] + (m[i][k] * n[k][j]);
}
}
}
return multipliedMatrix;
}
<强>输出强>
34.044.054.0
86.0112.0138.0
30.045.060.0
<强>解释强>
在每个循环中,您应该在索引小于length时运行 - 不小于length-1
答案 1 :(得分:0)
还可以将行/列长度保存为单独的变量。这样可以避免混淆使用什么矩阵的行/列。
<强>爪哇。矩阵乘法。
public class Matrix {
/**
* Matrix multiplication method.
* @param m1 Multiplicand
* @param m2 Multiplier
* @return Product
*/
public static double[][] multiplyByMatrix(double[][] m1, double[][] m2) {
int m1ColLength = m1[0].length; // m1 columns length
int m2RowLength = m2.length; // m2 rows length
if(m1ColLength != m2RowLength) return null; // matrix multiplication is not possible
int mRRowLength = m1.length; // m result rows length
int mRColLength = m2[0].length; // m result columns length
double[][] mResult = new double[mRRowLength][mRColLength];
for(int i = 0; i < mRRowLength; i++) { // rows from m1
for(int j = 0; j < mRColLength; j++) { // columns from m2
for(int k = 0; k < m1ColLength; k++) { // columns from m1
mResult[i][j] += m1[i][k] * m2[k][j];
}
}
}
return mResult;
}
public static String toString(double[][] m) {
String result = "";
for(int i = 0; i < m.length; i++) {
for(int j = 0; j < m[i].length; j++) {
result += String.format("%11.2f", m[i][j]);
}
result += "\n";
}
return result;
}
public static void main(String[] args) {
// #1
double[][] multiplicand = new double[][] {
{3, -1, 2},
{2, 0, 1},
{1, 2, 1}
};
double[][] multiplier = new double[][] {
{2, -1, 1},
{0, -2, 3},
{3, 0, 1}
};
System.out.println("#1\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
// #2
multiplicand = new double[][] {
{1, 2, 0},
{-1, 3, 1},
{2, -2, 1}
};
multiplier = new double[][] {
{2},
{-1},
{1}
};
System.out.println("#2\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
// #3
multiplicand = new double[][] {
{1, 2, -1},
{0, 1, 0}
};
multiplier = new double[][] {
{1, 1, 0, 0},
{0, 2, 1, 1},
{1, 1, 2, 2}
};
System.out.println("#3\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
}
}
<强>输出:强>
#1
12.00 -1.00 2.00
7.00 -2.00 3.00
5.00 -5.00 8.00
#2
0.00
-4.00
7.00
#3
0.00 4.00 0.00 0.00
0.00 2.00 1.00 1.00