如何使Strassen算法(矩阵多重飞行)更有效?

时间:2018-10-18 14:19:45

标签: python matrix multiplying strassen

结果是Strassen():27秒和muliply_mat():5秒

我怎么能比multiple_mat()更短?

我猜问题是strassen函数中的new_mat函数。

我应该采取什么措施减少strassen函数的时间。

def new_mat(a,b):
matrix = [[0 for row in range(a)]for col in range(b)]
return matrix

def add_mat(a, b):
    if type(a) == int:
        d = a + b
    else:
        d = []
        for i in range(len(a)):
            c = []
            for j in range(len(a[0])):
                c.append(a[i][j] + b[i][j])
            d.append(c)
    return d

def sub_mat(a, b):
    if type(a) == int:
        d = a - b
    else:
        d = []
        for i in range(len(a)):
            c = []
            for j in range(len(a[0])):
                c.append(a[i][j] - b[i][j])
            d.append(c)
    return d
def multiply_mat(a,b):
    if len(a[0]) != len(b):
        return("Error")
    else:
        result_mat = new_mat(len(a[0]), len(b))
        for i in range(len(a[0])):
            for j in range(len(b)):
                for k in range(len(b)):
                    result_mat[i][j] += a[i][k] * b[k][j]
        return(result_mat)

def split(matrix):
    a = matrix
    b = matrix
    c = matrix
    d = matrix
    while(len(a) > len(matrix)//2):
        a = a[:len(a)//2]
        b = b[:len(b)//2]
        c = c[len(c)//2:]
        d = d[len(d)//2:]
    while(len(a[0]) > len(matrix[0])//2):
        for i in range(len(a[0])//2):
            a[i] = a[i][:len(a[i])//2]
            b[i] = b[i][len(b[i])//2:]
            c[i] = c[i][:len(c[i])//2]
            d[i] = d[i][len(d[i])//2:]
    return a,b,c,d
def strassen(a,b,n):
    if n <= 2:
        return(multiply_mat(a,b))
    else : 
        a11,a12,a21,a22 = split(a)
        b11,b12,b21,b22 = split(b)

        # p1 = (a11+a22) * (b11+b22)
        m1 = strassen(add_mat(a11,a22), add_mat(b11,b22), n//2)

        # p2 = (a21+a22) * b11
        m2 = strassen(add_mat(a21,a22), b11, n//2)

        # p3 = a11 * (b12-b22)
        m3 = strassen(a11, sub_mat(b12,b22), n//2)

        # p4 = a22 * (b12-b11)
        m4 = strassen(a22, sub_mat(b21,b11), n//2)

        # p5 = (a11+a12) * b22
        m5 = strassen(add_mat(a11,a12), b22, n//2)

        # p6 = (a21-a11) * (b11+b12)
        m6 = strassen(sub_mat(a21,a11), add_mat(b11,b12), n//2)

        # p7 = (a12-a22) * (b21+b22)
        m7 = strassen(sub_mat(a12,a22), add_mat(b21,b22), n//2)

        # c11 = p1 + p4 - p5 + p7
        c11 = add_mat(sub_mat(add_mat(m1, m4), m5), m7)

        # c12 = p3 + p5
        c12 = add_mat(m3, m5)

        # c21 = p2 + p4
        c21 = add_mat(m2, m4)

        # c22 = p1 + p3 - p2 + p6
        c22 = add_mat(sub_mat(add_mat(m1, m3), m2), m6)

        c = new_mat(len(c11)*2,len(c11)*2)
        for i in range(len(c11)):
            for j in range(len(c11)):
                c[i][j]                   = c11[i][j]
                c[i][j+len(c11)]          = c12[i][j]
                c[i+len(c11)][j]          = c21[i][j]
                c[i+len(c11)][j+len(c11)] = c22[i][j]
        return c

a = new_mat(256,256)
b = new_mat(256,256)
import time
start_time = time.time()
tmp = strassen(a,b,256)
print("--- %s seconds ---" %(time.time() - start_time))

start_time2 = time.time()
tmp2 = multiply_mat(a,b)
print("--- %s seconds ---" %(time.time() - start_time2))

我想这段代码会花很长时间。

c = new_mat(len(c11)*2,len(c11)*2)
for i in range(len(c11)):
    for j in range(len(c11)):
        c[i][j]                   = c11[i][j]
        c[i][j+len(c11)]          = c12[i][j]
        c[i+len(c11)][j]          = c21[i][j]
        c[i+len(c11)][j+len(c11)] = c22[i][j]
return c

我如何从strassen()中摆脱出来?

0 个答案:

没有答案