我很难概念化如何实施Strassen的此算法版本。
对于后台,我对迭代版本有以下伪代码:
def Matrix(a,b):
result = []
for i in range(0,len(a)):
new_array = []
result.extend(new_array)
for j in range(0,len(b[0])):
ssum = 0
for k in range(0,len(a[0])):
ssum += a[i][k] * b[k][j]
result[i][j] = ssum
return result
我还有以下伪代码用于初始分而治之的版本:
def recMatrix(x,y):
if(len(x) == 1):
return x[0] * y[0]
z = []
z[0] = recMatrix(x[0], y[0]) + recMatrix(x[1], y[2])
z[1] = recMatrix(x[0], y[1]) + recMatrix(x[1], y[3])
z[2] = recMatrix(x[2], y[0]) + recMatrix(x[3], y[2])
z[3] = recMatrix(x[2], y[1]) + recMatrix(x[3], y[3])
return z
所以我的问题是,我对分而治之方法的理解是否正确,如果是这样,我如何修改以允许Strassen的方法? (这不是作业。)
(特别是我很难概念化它是在递归之前(或之后)实体本身的数学中。即P1 = A(FH)。如果递归主动乘以基本元素,如何是strassen递归在矩阵上处理算术(加法和减法)吗?我有以下伪代码来显示我的大脑在哪里:
def recMatrix(x,y):
if(len(x) == 1):
return x[0] * y[0]
z = []
p1 = recMatrix2(x[0], (y[1] - y[3]));
p2 = recMatrix2(y[3], (x[0] + x[1]));
p3 = recMatrix2(y[0], (x[2] + x[3]));
p4 = recMatrix2(x[3], (y[2] - y[0]));
p5 = recMatrix2((x[0] + x[3]), (y[0] + y[3]));
p6 = recMatrix2((x[1] - x[3]), (y[2] + y[3]));
p7 = recMatrix2((x[0] - x[3]), (y[0] + y[3]));
z[0] = p5 + p4 - p2 + p6;
z[1] = p1 + p2;
z[2] = p3 + p4;
z[3] = p1 + p5 - p3 - p7;
return z
答案 0 :(得分:0)
最后一段代码的问题在于它没有采用正确的子矩阵。例如,在p1
中您想要采用x
的左上子矩阵,但您使用的是x[0]
,它只是x
的第一行。您需要一些代码将矩阵分成4个较小的矩阵。或者您可以使用像numpy这样的数学库:
In [14]: from numpy import *
In [15]: x=range(16)
In [16]: x=reshape(x,(4,4))
In [17]: x
Out[17]:
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
In [18]: x[0:2,0:2]
Out[18]:
array([[0, 1],
[4, 5]])
In [19]: x[2:4,2:4]
Out[19]:
array([[10, 11],
[14, 15]])
答案 1 :(得分:0)
找到了一个可以完成我正在寻找的实现...即,它具体显示了如何/何时递归:https://github.com/MartinThoma/matrix-multiplication/blob/master/Python/strassen-algorithm.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
from optparse import OptionParser
from math import ceil, log
def read(filename):
lines = open(filename, 'r').read().splitlines()
A = []
B = []
matrix = A
for line in lines:
if line != "":
matrix.append(map(int, line.split("\t")))
else:
matrix = B
return A, B
def printMatrix(matrix):
for line in matrix:
print "\t".join(map(str,line))
def add(A, B):
n = len(A)
C = [[0 for j in xrange(0, n)] for i in xrange(0, n)]
for i in xrange(0, n):
for j in xrange(0, n):
C[i][j] = A[i][j] + B[i][j]
return C
def subtract(A, B):
n = len(A)
C = [[0 for j in xrange(0, n)] for i in xrange(0, n)]
for i in xrange(0, n):
for j in xrange(0, n):
C[i][j] = A[i][j] - B[i][j]
return C
def strassenR(A, B):
""" Implementation of the strassen algorithm, similar to
http://en.wikipedia.org/w/index.php?title=Strassen_algorithm&oldid=498910018#Source_code_of_the_Strassen_algorithm_in_C_language
"""
n = len(A)
# Trivial Case: 1x1 Matrices
if n == 1:
return [[A[0][0]*B[0][0]]]
else:
# initializing the new sub-matrices
newSize = n/2
a11 = [[0 for j in xrange(0, newSize)] for i in xrange(0, newSize)]
a12 = [[0 for j in xrange(0, newSize)] for i in xrange(0, newSize)]
a21 = [[0 for j in xrange(0, newSize)] for i in xrange(0, newSize)]
a22 = [[0 for j in xrange(0, newSize)] for i in xrange(0, newSize)]
b11 = [[0 for j in xrange(0, newSize)] for i in xrange(0, newSize)]
b12 = [[0 for j in xrange(0, newSize)] for i in xrange(0, newSize)]
b21 = [[0 for j in xrange(0, newSize)] for i in xrange(0, newSize)]
b22 = [[0 for j in xrange(0, newSize)] for i in xrange(0, newSize)]
aResult = [[0 for j in xrange(0, newSize)] for i in xrange(0, newSize)]
bResult = [[0 for j in xrange(0, newSize)] for i in xrange(0, newSize)]
# dividing the matrices in 4 sub-matrices:
for i in xrange(0, newSize):
for j in xrange(0, newSize):
a11[i][j] = A[i][j]; # top left
a12[i][j] = A[i][j + newSize]; # top right
a21[i][j] = A[i + newSize][j]; # bottom left
a22[i][j] = A[i + newSize][j + newSize]; # bottom right
b11[i][j] = B[i][j]; # top left
b12[i][j] = B[i][j + newSize]; # top right
b21[i][j] = B[i + newSize][j]; # bottom left
b22[i][j] = B[i + newSize][j + newSize]; # bottom right
# Calculating p1 to p7:
aResult = add(a11, a22)
bResult = add(b11, b22)
p1 = strassen(aResult, bResult) # p1 = (a11+a22) * (b11+b22)
aResult = add(a21, a22) # a21 + a22
p2 = strassen(aResult, b11) # p2 = (a21+a22) * (b11)
bResult = subtract(b12, b22) # b12 - b22
p3 = strassen(a11, bResult) # p3 = (a11) * (b12 - b22)
bResult = subtract(b21, b11) # b21 - b11
p4 =strassen(a22, bResult) # p4 = (a22) * (b21 - b11)
aResult = add(a11, a12) # a11 + a12
p5 = strassen(aResult, b22) # p5 = (a11+a12) * (b22)
aResult = subtract(a21, a11) # a21 - a11
bResult = add(b11, b12) # b11 + b12
p6 = strassen(aResult, bResult) # p6 = (a21-a11) * (b11+b12)
aResult = subtract(a12, a22) # a12 - a22
bResult = add(b21, b22) # b21 + b22
p7 = strassen(aResult, bResult) # p7 = (a12-a22) * (b21+b22)
# calculating c21, c21, c11 e c22:
c12 = add(p3, p5) # c12 = p3 + p5
c21 = add(p2, p4) # c21 = p2 + p4
aResult = add(p1, p4) # p1 + p4
bResult = add(aResult, p7) # p1 + p4 + p7
c11 = subtract(bResult, p5) # c11 = p1 + p4 - p5 + p7
aResult = add(p1, p3) # p1 + p3
bResult = add(aResult, p6) # p1 + p3 + p6
c22 = subtract(bResult, p2) # c22 = p1 + p3 - p2 + p6
# Grouping the results obtained in a single matrix:
C = [[0 for j in xrange(0, n)] for i in xrange(0, n)]
for i in xrange(0, newSize):
for j in xrange(0, newSize):
C[i][j] = c11[i][j]
C[i][j + newSize] = c12[i][j]
C[i + newSize][j] = c21[i][j]
C[i + newSize][j + newSize] = c22[i][j]
return C
def strassen(A, B):
assert type(A) == list and type(B) == list
assert len(A) == len(A[0]) == len(B) == len(B[0])
nextPowerOfTwo = lambda n: 2**int(ceil(log(n,2)))
n = len(A)
m = nextPowerOfTwo(n)
APrep = [[0 for i in xrange(m)] for j in xrange(m)]
BPrep = [[0 for i in xrange(m)] for j in xrange(m)]
for i in xrange(n):
for j in xrange(n):
APrep[i][j] = A[i][j]
BPrep[i][j] = B[i][j]
CPrep = strassenR(APrep, BPrep)
C = [[0 for i in xrange(n)] for j in xrange(n)]
for i in xrange(n):
for j in xrange(n):
C[i][j] = CPrep[i][j]
return C
if __name__ == "__main__":
parser = OptionParser()
parser.add_option("-i", dest="filename", default="2000.in",
help="input file with two matrices", metavar="FILE")
(options, args) = parser.parse_args()
A, B = read(options.filename)
C = strassen(A, B)
printMatrix(C)