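I am currently developing a class that represents any general m×n matrix. I have worked out addition and scalar multiplication, but I am struggling to implement the multiplication of two matrices. The matrix data is held in a 2D array of doubles.
The method looks a bit like this: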
public Matrix multiply(Matrix A) {
    // code
}
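It will return the product matrix. This is multiplication on the right: if I call A.multiply(B), it returns the matrix AB, with B on the right.
I don't yet need to worry about checking whether the multiplication is defined for the given matrices; I can assume I will be given matrices of the correct dimensions.
Does anyone know of a simple algorithm, perhaps even in pseudocode, for carrying out the multiplication?
Thanks in advance.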
Answer 0 (score: 10)
Mathematically, the product of the matrices A (l × m) and B (m × n) is defined as the matrix C (l × n) whose elements are:
$$c_{ij} = \sum_{k=1}^{m} a_{ik}\, b_{kj}$$
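As a small worked instance of this definition (the numbers are picked here purely for illustration and are not from the original answer), take l = m = n = 2. Each entry of the product is the sum of products of one row of A with one column of B:

```latex
\[
\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}
\begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix}
=
\begin{pmatrix}
1\cdot 5 + 2\cdot 7 & 1\cdot 6 + 2\cdot 8 \\
3\cdot 5 + 4\cdot 7 & 3\cdot 6 + 4\cdot 8
\end{pmatrix}
=
\begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix}
\]
```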
So if you don't need top speed, you may be happy with the straightforward O(n^3) implementation (O(l·m·n) for the dimensions above):
// c must start out filled with zeros (true for a freshly allocated Java array)
for (int i = 0; i < l; ++i)
    for (int j = 0; j < n; ++j)
        for (int k = 0; k < m; ++k)
            c[i][j] += a[i][k] * b[k][j];
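To connect this loop with the method signature from the question, here is a minimal sketch of how it might sit inside a class. The class name `SimpleMatrix`, the `data` field, and the `get` accessor are assumptions made for this example, not the asker's actual code, and dimensions are not checked, as the question allows.

```java
// Hypothetical minimal matrix class; only what multiply needs is shown.
final class SimpleMatrix {
    private final double[][] data; // assumed layout: data[row][column]

    SimpleMatrix(double[][] data) {
        this.data = data;
    }

    double get(int row, int col) {
        return data[row][col];
    }

    // Returns this * other, with other as the right-hand factor.
    // Assumes this is l x m and other is m x n; no dimension check is done.
    SimpleMatrix multiply(SimpleMatrix other) {
        int l = data.length;
        int m = other.data.length;
        int n = other.data[0].length;
        double[][] c = new double[l][n]; // new arrays start out as all zeros
        for (int i = 0; i < l; ++i)
            for (int j = 0; j < n; ++j)
                for (int k = 0; k < m; ++k)
                    c[i][j] += data[i][k] * other.data[k][j];
        return new SimpleMatrix(c);
    }

    public static void main(String[] args) {
        SimpleMatrix a = new SimpleMatrix(new double[][] {{1, 2}, {3, 4}});
        SimpleMatrix b = new SimpleMatrix(new double[][] {{5, 6}, {7, 8}});
        SimpleMatrix ab = a.multiply(b);
        System.out.println(ab.get(0, 0) + " " + ab.get(0, 1)); // 19.0 22.0
        System.out.println(ab.get(1, 0) + " " + ab.get(1, 1)); // 43.0 50.0
    }
}
```

Because the result is built in a fresh array, neither operand is modified, so `a.multiply(a)` is safe as well.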
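If you want more speed, you may want to look at alternatives such as the Strassen algorithm (see: Strassen algorithm).
Be warned, however: especially when multiplying small matrices on modern processor architectures, speed depends heavily on how the matrix data and the order of the multiplications are arranged to make the best use of cache lines.
I strongly doubt there is any chance of influencing this factor from within a VM, so I am not sure whether it needs to be taken into consideration.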
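For readers who do want to experiment with access order, one commonly cited rearrangement is to swap the two inner loops so that the innermost loop walks along a row of both b and c. The sketch below assumes row-major `double[][]` storage; the class and method names are made up for illustration, and whether it actually helps on a particular JVM would have to be measured.

```java
import java.util.Arrays;

final class LoopOrderSketch {
    // Same arithmetic as the i-j-k loop, but in i-k-j order:
    // the innermost loop reads b[k][*] and updates c[i][*] sequentially.
    static double[][] multiplyIkj(double[][] a, double[][] b) {
        int l = a.length, m = b.length, n = b[0].length;
        double[][] c = new double[l][n];
        for (int i = 0; i < l; ++i) {
            double[] cRow = c[i];
            for (int k = 0; k < m; ++k) {
                double aik = a[i][k];   // reused for the whole inner loop
                double[] bRow = b[k];
                for (int j = 0; j < n; ++j) {
                    cRow[j] += aik * bRow[j];
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2, -1}, {0, 1, 0}};
        double[][] b = {{1, 1, 0, 0}, {0, 2, 1, 1}, {1, 1, 2, 2}};
        // prints [[0.0, 4.0, 0.0, 0.0], [0.0, 2.0, 1.0, 1.0]]
        System.out.println(Arrays.deepToString(multiplyIkj(a, b)));
    }
}
```

The result is the same as with the i-j-k order for these inputs; only the order in which the partial products are accumulated changes.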
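Answer 1 (score: 1)
Java. Matrix multiplication.
Here is "code to perform the multiplication process". It has been tested with matrices of different sizes.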
public class Matrix {
/**
* Matrix multiplication method.
* @param m1 Multiplicand
* @param m2 Multiplier
* @return Product
*/
public static double[][] multiplyByMatrix(double[][] m1, double[][] m2) {
int m1ColLength = m1[0].length; // m1 columns length
int m2RowLength = m2.length; // m2 rows length
if(m1ColLength != m2RowLength) return null; // matrix multiplication is not possible
int mRRowLength = m1.length; // m result rows length
int mRColLength = m2[0].length; // m result columns length
double[][] mResult = new double[mRRowLength][mRColLength];
for(int i = 0; i < mRRowLength; i++) { // rows from m1
for(int j = 0; j < mRColLength; j++) { // columns from m2
for(int k = 0; k < m1ColLength; k++) { // columns from m1
mResult[i][j] += m1[i][k] * m2[k][j];
}
}
}
return mResult;
}
public static String toString(double[][] m) {
String result = "";
for(int i = 0; i < m.length; i++) {
for(int j = 0; j < m[i].length; j++) {
result += String.format("%11.2f", m[i][j]);
}
result += "\n";
}
return result;
}
public static void main(String[] args) {
// #1
double[][] multiplicand = new double[][] {
{3, -1, 2},
{2, 0, 1},
{1, 2, 1}
};
double[][] multiplier = new double[][] {
{2, -1, 1},
{0, -2, 3},
{3, 0, 1}
};
System.out.println("#1\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
// #2
multiplicand = new double[][] {
{1, 2, 0},
{-1, 3, 1},
{2, -2, 1}
};
multiplier = new double[][] {
{2},
{-1},
{1}
};
System.out.println("#2\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
// #3
multiplicand = new double[][] {
{1, 2, -1},
{0, 1, 0}
};
multiplier = new double[][] {
{1, 1, 0, 0},
{0, 2, 1, 1},
{1, 1, 2, 2}
};
System.out.println("#3\n" + toString(multiplyByMatrix(multiplicand, multiplier)));
}
}
Output:
#1
12.00 -1.00 2.00
7.00 -2.00 3.00
5.00 -5.00 8.00
#2
0.00
-4.00
7.00
#3
0.00 4.00 0.00 0.00
0.00 2.00 1.00 1.00
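Returning null on a dimension mismatch defers the failure to wherever the result is first used. An alternative, sketched below on the assumption that callers would rather get an exception, is to throw the standard `IllegalArgumentException`. The class name `CheckedMultiply` is invented for this example and the arithmetic is unchanged.

```java
final class CheckedMultiply {
    // Same triple loop as above, but a mismatch is reported with an exception.
    static double[][] multiply(double[][] m1, double[][] m2) {
        if (m1.length == 0 || m2.length == 0) {
            throw new IllegalArgumentException("Matrices must have at least one row");
        }
        if (m1[0].length != m2.length) {
            throw new IllegalArgumentException(
                    "Cannot multiply: " + m1[0].length + " columns vs " + m2.length + " rows");
        }
        int rows = m1.length, inner = m2.length, cols = m2[0].length;
        double[][] result = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                for (int k = 0; k < inner; k++) {
                    result[i][j] += m1[i][k] * m2[k][j];
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2, 3}}; // 1 x 3
        double[][] b = {{1}, {2}};  // 2 x 1, so the inner dimensions 3 and 2 differ
        try {
            multiply(a, b);
        } catch (IllegalArgumentException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
    }
}
```

Which style fits better depends on the surrounding code; the exception variant simply makes the mismatch harder to overlook.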
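Answer 2 (score: 0)
In this answer, I created a class named Matrix and another class named MatrixOperations, which defines the various operations that can be performed on matrices (except row operations, of course). Here I will extract only the multiplication code from MatrixOperations. The full project is available on my GitHub page here.
Below is the definition of the Matrix class.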
package app.matrix;
import app.matrix.util.MatrixException;
public class Matrix {
private double[][] entries;
public void setEntries(double[][] entries) {
this.entries = entries;
}
private String name;
public double[][] getEntries() {
return entries;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public class Dimension {
private int rows;
private int columns;
public int getRows() {
return rows;
}
public void setRows(int rows) {
this.rows = rows;
}
public int getColumns() {
return columns;
}
public void setColumns(int columns) {
this.columns = columns;
}
public Dimension(int rows, int columns) {
this.setRows(rows);
this.setColumns(columns);
}
@Override
public boolean equals(Object obj) {
if(obj instanceof Dimension){
return (this.getColumns() == ((Dimension) obj).getColumns()) && (this.getRows() == ((Dimension) obj).getRows());
}
return false;
}
}
private Dimension dimension;
public Dimension getDimension() {
return dimension;
}
public void setDimension(Dimension dimension) {
this.dimension = dimension;
}
public Matrix(int dimension, String name) throws MatrixException {
if (dimension == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
else this.setEntries(new double[Math.abs(dimension)][Math.abs(dimension)]);
this.setDimension(new Dimension(dimension, dimension));
this.setName(name);
}
public Matrix(int dimensionH, int dimensionV, String name) throws MatrixException {
if (dimensionH == 0 || dimensionV == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
else this.setEntries(new double[Math.abs(dimensionH)][Math.abs(dimensionV)]);
this.setDimension(new Dimension(dimensionH, dimensionV));
this.setName(name);
}
private static final String OVERFLOW_ITEMS_MSG = "The values are too many for the matrix's specified dimensions";
private static final String ZERO_UNIT_DIMENSION = "Zero cannot be a value for a dimension";
public Matrix(int dimensionH, int dimensionV, String name, double... values) throws MatrixException {
if (dimensionH == 0 || dimensionV == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
else if (values.length > dimensionH * dimensionV) throw new MatrixException(Matrix.OVERFLOW_ITEMS_MSG);
else this.setEntries(new double[Math.abs(dimensionH)][Math.abs(dimensionV)]);
this.setDimension(new Dimension(dimensionH, dimensionV));
this.setName(name);
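int iterator = 0;
// Fill row by row; cells beyond the supplied values keep the default 0.0,
// so passing fewer values than cells does not throw ArrayIndexOutOfBoundsException
for (int i = 0; i < dimensionH; i++) {
for (int j = 0; j < dimensionV && iterator < values.length; j++) {
this.entries[i][j] = values[iterator++];
}
}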
}
public Matrix(Dimension dimension) throws MatrixException {
this(dimension.getRows(), dimension.getColumns(), null);
}
public static Matrix identityMatrix(int dim) throws MatrixException {
if (dim == 0) throw new MatrixException(ZERO_UNIT_DIMENSION);
double[] i = new double[dim * dim];
int constant = dim + 1;
for (int j = 0; j < i.length; j = j + constant) {
i[j] = 1.0;
}
return new Matrix(dim, dim, null, i);
}
public String toString() {
StringBuilder builder = new StringBuilder("Matrix \"" + (this.getName() == null ? "Null Matrix" : this.getName()) + "\": {\n");
for (int i = 0; i < this.getDimension().getRows(); i++) {
for (int j = 0; j < this.getDimension().getColumns(); j++) {
if (j == 0) builder.append("\t");
builder.append(this.entries[i][j]);
if (j != this.getDimension().getColumns() - 1)
builder.append(", ");
}
if (i != this.getDimension().getRows()) builder.append("\n");
}
builder.append("}");
return builder.toString();
}
public boolean isSquare() {
return this.getDimension().getColumns() == this.getDimension().getRows();
}
}
Here is the code for the matrix multiplication method in MatrixOperations:
public static Matrix multiply(Matrix matrix1, Matrix matrix2) throws MatrixException {
if (matrix1.getDimension().getColumns() != matrix2.getDimension().getRows())
throw new MatrixException(MATRIX_MULTIPLICATION_ERROR_MSG);
Matrix retVal = new Matrix(matrix1.getDimension().getRows(), matrix2.getDimension().getColumns(), matrix1.getName() + " x " + matrix2.getName());
for (int i = 0; i < matrix1.getDimension().getRows(); i++) {
for (int j = 0; j < matrix2.getDimension().getColumns(); j++) {
retVal.getEntries()[i][j] = sum(arrayProduct(matrix1.getEntries()[i], getColumnMatrix(matrix2, j)));
}
}
return retVal;
}
Below is the code for the methods sum, arrayProduct, and getColumnMatrix:
private static double sum(double... values) {
double sum = 0;
for (double value : values) {
sum += value;
}
return sum;
}
private static double[] arrayProduct(double[] arr1, double[] arr2) throws MatrixException {
if (arr1.length != arr2.length) throw new MatrixException("Array lengths must be the same");
double[] retVal = new double[arr1.length];
for (int i = 0; i < arr1.length; i++) {
retVal[i] = arr1[i] * arr2[i];
}
return retVal;
}
private static double[] getColumnMatrix(Matrix matrix, int col) {
double[] ret = new double[matrix.getDimension().getRows()];
for (int i = 0; i < matrix.getDimension().getRows(); i++) {
ret[i] = matrix.getEntries()[i][col];
}
return ret;
}
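The row-times-column decomposition used by these helpers can also be tried without the Matrix class. The following is a self-contained sketch on plain `double[][]` arrays; the names `DotProductSketch`, `column`, and `dot` are invented here and are not part of the project above. As a small variation, it extracts each column once per j rather than once per (i, j) pair, which saves some temporary arrays.

```java
import java.util.Arrays;

final class DotProductSketch {
    // Copies column col of m into a new array.
    static double[] column(double[][] m, int col) {
        double[] out = new double[m.length];
        for (int r = 0; r < m.length; r++) {
            out[r] = m[r][col];
        }
        return out;
    }

    // Sum of element-wise products; assumes x and y have the same length.
    static double dot(double[] x, double[] y) {
        double sum = 0;
        for (int i = 0; i < x.length; i++) {
            sum += x[i] * y[i];
        }
        return sum;
    }

    // Entry (i, j) of the product is row i of a dotted with column j of b.
    static double[][] multiply(double[][] a, double[][] b) {
        double[][] c = new double[a.length][b[0].length];
        for (int j = 0; j < b[0].length; j++) {
            double[] bCol = column(b, j); // extracted once per column
            for (int i = 0; i < a.length; i++) {
                c[i][j] = dot(a[i], bCol);
            }
        }
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{3, -1, 2}, {2, 0, 1}, {1, 2, 1}};
        double[][] b = {{2, -1, 1}, {0, -2, 3}, {3, 0, 1}};
        // prints [[12.0, -1.0, 2.0], [7.0, -2.0, 3.0], [5.0, -5.0, 8.0]]
        System.out.println(Arrays.deepToString(multiply(a, b)));
    }
}
```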
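Answer 3 (score: 0)
Try this code to multiply two arrays of any compatible dimensions and print the result. I think it is simpler, and anyone can understand it.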
public class Test
{
public static void main(String[] args)
{
int[][] array1 = {
{ 1, 4, -2 },
{ 3, 5, -6 },
{ 4, 5, 2 }
};
int[][] array2 = {
{ 5, 2, 8, -1 },
{ 3, 6, 4, 5 },
{ -2, 9, 7, -3 }
};
Test test = new Test();
test.printArray(test.multiplication(array1, array2));
}
private int[][] multiplication(int[][] array1, int[][] array2)
{
int r1, r2, c1, c2;
r1 = array1.length;
c1 = array1[0].length;
r2 = array2.length;
c2 = array2[0].length;
int[][] result;
if (c1 != r2)
{
System.out.println("Error!");
result = new int[0][0];
}
else
{
result = new int[r1][c2];
for (int i = 0; i < r1; i++) // each row of array1
{
for (int j = 0; j < c2; j++) // each column of array2
{
for (int k = 0; k < c1; k++)
{
result[i][j] += array1[i][k] * array2[k][j];
}
}
}
}
return result;
}
private void printArray(int[][] array)
{
for (int[] arr : array)
{
for (int element : arr)
{
System.out.print(element + " ");
}
System.out.println();
}
}
}