我需要你的帮助。我的计划是使用python实现Strassen-Algorithm,但我的问题是它仅适用于相同大小的矩阵,并且如果一个矩阵大于另一个矩阵,则会出现索引超出范围的错误。
我认为解决问题的方法是用零填充新矩阵,但说实话我真的不知道这是否可行或如何实现。
def showMatrix(m):
for line in m:
print('|', end='')
i = 0
for value in line:
if (i > 0):
print(' ', end='')
print(value, end='')
i = i + 1
print("|")
def add_m(a, b):
if type(a) == int:
d = a + b
else:
d = []
for i in range(len(a)):
c = []
for j in range(len(a[0])):
c.append(a[i][j] + b[i][j])
d.append(c)
return d
def sub_m(a, b):
if type(a) == int:
d = a - b
else:
d = []
for i in range(len(a)):
c = []
for j in range(len(a[0])):
c.append(a[i][j] - b[i][j])
d.append(c)
return d
def MatMult(A, B):
n = len(A)
C = [[0 for i in range(n)] for j in range(n)]
for i in range(n):
for k in range(n):
for j in range(n):
C[i][j] += A[i][k] * B[k][j]
return C
def split(matrix): #splits matrix in quarters
a = matrix
b = matrix
c = matrix
d = matrix
while(len(a) > len(matrix)/2):
a = a[:len(a)//2]
b = b[:len(b)//2]
c = c[len(c)//2:]
d = d[len(d)//2:]
while(len(a[0]) > len(matrix[0])/2):
for i in range(len(a[0])//2):
a[i] = a[i][:len(a[i])//2]
b[i] = b[i][len(b[i])//2:]
c[i] = c[i][:len(c[i])//2]
d[i] = d[i][len(d[i])//2:]
return a,b,c,d
def strassen(a, b,q):
print('Parameter a:')
showMatrix(a)
print('Parameter b:')
showMatrix(b)
q = len(a)
if q == 1:
return MatMult(a, b)
else:
a11, a12, a21, a22 = split(a)
b11, b12, b21, b22 = split(b)
# p1 = (a11+a22) * (b11+b22)
p1 = strassen(add_m(a11,a22), add_m(b11,b22), q/2)
# p2 = (a21+a22) * b11
p2 = strassen(add_m(a21,a22), b11, q/2)
# p3 = a11 * (b12-b22)
p3 = strassen(a11, sub_m(b12,b22), q/2)
# p4 = a22 * (b12-b11)
p4 = strassen(a22, sub_m(b21,b11), q/2)
# p5 = (a11+a12) * b22
p5 = strassen(add_m(a11,a12), b22, q/2)
# p6 = (a21-a11) * (b11+b12)
p6 = strassen(sub_m(a21,a11), add_m(b11,b12), q/2)
# p7 = (a12-a22) * (b21+b22)
p7 = strassen(sub_m(a12,a22), add_m(b21,b22), q/2)
# c11 = p1 + p4 - p5 + p7
c11 = add_m(sub_m(add_m(p1, p4), p5), p7)
# c12 = p3 + p5
c12 = add_m(p3, p5)
# c21 = p2 + p4
c21 = add_m(p2, p4)
# c22 = p1 + p3 - p2 + p6
c22 = add_m(sub_m(add_m(p1, p3), p2), p6)
c = [[0 for i in range(0, q)] for i in range(0, q)]
for i in range(len(c11)):
for j in range(len(c11)):
c[i][j] = c11[i][j]
c[i][j+len(c11)] = c12[i][j]
c[i+len(c11)][j] = c21[i][j]
c[i+len(c11)][j+len(c11)] = c22[i][j]
return c
a, b = [ [3, 2, 1],[1, 0, 2]], [[1, 2],[0, 1],[4, 0]]
result = strassen(a,b,2)
showMatrix(result)
结果应该是[[7,8],[9,2]],但就像我说的那样,我得到的索引超出范围错误。