作为类代码项目的一部分,分配的一部分是构造一个算法,该算法将使用Strassen方法将两个矩阵相乘。
我已经创建了必要的算法,但是我在getSize()调用时遇到了Stack Overflow错误。任何帮助,弄清楚为什么会这样做。
我没有附加所有与该问题有关的代码。
// CODE CALLING THE MULTIPLICATION
Matrix m = new Matrix(Integer.parseInt(size.getText()), 20);
Matrix m2 = new Matrix(Integer.parseInt(size.getText()), 20);
Matrix m3 = m.strassenMult(m2);
// THE MATRIX CLASS
public class Matrix {
private int[][] vals;
private int size;
private boolean stable = false;
private long time;
// Constructor for known data
public Matrix(int[][] vals) {
this.vals = vals;
this.size = vals.length;
stable = true;
}
// Constructor for matrix with random size, if bounds = 0 it is empty
public Matrix(int size, int bounds) {
this.size = size;
vals = new int[size][size];
stable = true;
if (bounds != 0) {
Random rand = new Random(System.currentTimeMillis());
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
int x = rand.nextInt();
x = x % bounds;
vals[j][i] = x;
}
}
}
}
private Matrix add(Matrix m2) {
///////////////////////// THE ERROR IS ON THE BELOW LINE
int n = m2.getSize();
int[][] newVals = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
newVals[i][j] = vals[i][j] + m2.getVal(i, j);
}
}
Matrix sum = new Matrix(newVals);
return sum;
}
private Matrix subtract(Matrix m2) {
int n = m2.getSize();
int[][] newVals = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
newVals[i][j] = vals[i][j] - m2.getVal(i, j);
}
}
return new Matrix(newVals);
}
// The strassen multiplication algorithm
public Matrix strassenMult(Matrix m2) {
int n = m2.getSize();
int newSize = n / 2;
// initialize new sub-matricies
Matrix a11 = new Matrix(newSize, 0);
Matrix a12 = new Matrix(newSize, 0);
Matrix a21 = new Matrix(newSize, 0);
Matrix a22 = new Matrix(newSize, 0);
Matrix b11 = new Matrix(newSize, 0);
Matrix b12 = new Matrix(newSize, 0);
Matrix b21 = new Matrix(newSize, 0);
Matrix b22 = new Matrix(newSize, 0);
Matrix aResult = new Matrix(newSize, 0);
Matrix bResult = new Matrix(newSize, 0);
// divide existing matries into the sub-matricies
for (int i = 0; i < newSize; i++) {
for (int j = 0; j < newSize; j++) {
a11.set(i, j, vals[i][j]); // top left
a12.set(i, j, vals[i][j + newSize]); // top right
a21.set(i, j, vals[i + newSize][j]); // bottom left
a22.set(i, j, vals[i + newSize][j + newSize]); // bottom right
b11.set(i, j, m2.getVal(i, j)); // top left
b12.set(i, j,m2.getVal(i, j + newSize)); // top right
b21.set(i, j, m2.getVal(i + newSize, j)); // bottom left
b22.set(i, j, m2.getVal(i + newSize, j + newSize)); // bottom right
}
}
// Calculating p1 to p7:
////////////////////////////////// ERROR IS CALLED BY THIS METHOD
aResult = a11.add(a22);
bResult = b11.add(b22);
Matrix p1 = aResult.strassenMult(bResult);
// p1 = (a11+a22) * (b11+b22)
aResult = a21.add(a22); // a21 + a22
Matrix p2 = aResult.strassenMult(b11); // p2 = (a21+a22) * (b11)
bResult = b12.subtract(b22); // b12 - b22
Matrix p3 = a11.strassenMult(bResult);
// p3 = (a11) * (b12 - b22)
bResult = b21.subtract(b11); // b21 - b11
Matrix p4 = a22.strassenMult(bResult);
// p4 = (a22) * (b21 - b11)
aResult = a11.add(a12); // a11 + a12
Matrix p5 = aResult.strassenMult(b22);
// p5 = (a11+a12) * (b22)
aResult = a21.subtract(a11); // a21 - a11
bResult = b11.add(b12); // b11 + b12
Matrix p6 = aResult.strassenMult(bResult);
// p6 = (a21-a11) * (b11+b12)
aResult = a12.subtract(a22); // a12 - a22
bResult = b21.add(b22); // b21 + b22
Matrix p7 = aResult.strassenMult(bResult);
// p7 = (a12-a22) * (b21+b22)
// calculating c21, c21, c11 e c22:
Matrix c12 = p3.add(p5); // c12 = p3 + p5
Matrix c21 = p2.add(p4); // c21 = p2 + p4
aResult = p1.add(p4); // p1 + p4
bResult = aResult.add(p7); // p1 + p4 + p7
Matrix c11 = bResult.subtract(p5);
// c11 = p1 + p4 - p5 + p7
aResult = p1.add(p3); // p1 + p3
bResult = aResult.add(p6); // p1 + p3 + p6
Matrix c22 = bResult.subtract(p2);
// c22 = p1 + p3 - p2 + p6
// Grouping the results obtained in a single matrix:
Matrix result = new Matrix(n, 0);
for (int i = 0; i < newSize; i++) {
for (int j = 0; j < newSize; j++) {
result.set(i, j, c11.getVal(i, j));
result.set(i, j + newSize, c12.getVal(i, j));
result.set(i + newSize, j, c21.getVal(i, j));
result.set(i + newSize, j + newSize, c22.getVal(i, j));
}
}
return result;
}
public int getSize() {
return size;
}
答案 0 :(得分:2)
您的Strassen乘法功能缺少基本情况。在某些时候,您需要停止递归并调用不同的矩阵乘法算法。