结果是Strassen():27秒和muliply_mat():5秒
我怎么能比multiple_mat()更短?
我猜问题是strassen函数中的new_mat函数。
我应该采取什么措施减少strassen函数的时间。
def new_mat(a,b):
matrix = [[0 for row in range(a)]for col in range(b)]
return matrix
def add_mat(a, b):
if type(a) == int:
d = a + b
else:
d = []
for i in range(len(a)):
c = []
for j in range(len(a[0])):
c.append(a[i][j] + b[i][j])
d.append(c)
return d
def sub_mat(a, b):
if type(a) == int:
d = a - b
else:
d = []
for i in range(len(a)):
c = []
for j in range(len(a[0])):
c.append(a[i][j] - b[i][j])
d.append(c)
return d
def multiply_mat(a,b):
if len(a[0]) != len(b):
return("Error")
else:
result_mat = new_mat(len(a[0]), len(b))
for i in range(len(a[0])):
for j in range(len(b)):
for k in range(len(b)):
result_mat[i][j] += a[i][k] * b[k][j]
return(result_mat)
def split(matrix):
a = matrix
b = matrix
c = matrix
d = matrix
while(len(a) > len(matrix)//2):
a = a[:len(a)//2]
b = b[:len(b)//2]
c = c[len(c)//2:]
d = d[len(d)//2:]
while(len(a[0]) > len(matrix[0])//2):
for i in range(len(a[0])//2):
a[i] = a[i][:len(a[i])//2]
b[i] = b[i][len(b[i])//2:]
c[i] = c[i][:len(c[i])//2]
d[i] = d[i][len(d[i])//2:]
return a,b,c,d
def strassen(a,b,n):
if n <= 2:
return(multiply_mat(a,b))
else :
a11,a12,a21,a22 = split(a)
b11,b12,b21,b22 = split(b)
# p1 = (a11+a22) * (b11+b22)
m1 = strassen(add_mat(a11,a22), add_mat(b11,b22), n//2)
# p2 = (a21+a22) * b11
m2 = strassen(add_mat(a21,a22), b11, n//2)
# p3 = a11 * (b12-b22)
m3 = strassen(a11, sub_mat(b12,b22), n//2)
# p4 = a22 * (b12-b11)
m4 = strassen(a22, sub_mat(b21,b11), n//2)
# p5 = (a11+a12) * b22
m5 = strassen(add_mat(a11,a12), b22, n//2)
# p6 = (a21-a11) * (b11+b12)
m6 = strassen(sub_mat(a21,a11), add_mat(b11,b12), n//2)
# p7 = (a12-a22) * (b21+b22)
m7 = strassen(sub_mat(a12,a22), add_mat(b21,b22), n//2)
# c11 = p1 + p4 - p5 + p7
c11 = add_mat(sub_mat(add_mat(m1, m4), m5), m7)
# c12 = p3 + p5
c12 = add_mat(m3, m5)
# c21 = p2 + p4
c21 = add_mat(m2, m4)
# c22 = p1 + p3 - p2 + p6
c22 = add_mat(sub_mat(add_mat(m1, m3), m2), m6)
c = new_mat(len(c11)*2,len(c11)*2)
for i in range(len(c11)):
for j in range(len(c11)):
c[i][j] = c11[i][j]
c[i][j+len(c11)] = c12[i][j]
c[i+len(c11)][j] = c21[i][j]
c[i+len(c11)][j+len(c11)] = c22[i][j]
return c
a = new_mat(256,256)
b = new_mat(256,256)
import time
start_time = time.time()
tmp = strassen(a,b,256)
print("--- %s seconds ---" %(time.time() - start_time))
start_time2 = time.time()
tmp2 = multiply_mat(a,b)
print("--- %s seconds ---" %(time.time() - start_time2))
我想这段代码会花很长时间。
c = new_mat(len(c11)*2,len(c11)*2)
for i in range(len(c11)):
for j in range(len(c11)):
c[i][j] = c11[i][j]
c[i][j+len(c11)] = c12[i][j]
c[i+len(c11)][j] = c21[i][j]
c[i+len(c11)][j+len(c11)] = c22[i][j]
return c
我如何从strassen()中摆脱出来?