python中矩阵乘法的多线程

时间:2013-03-18 23:43:05

标签: python multithreading

我想创建n个线程,每个线程计算结果矩阵的整行。我试过以下代码,

import numpy
import random
import threading

class MatrixMult(threading.Thread):
    """A thread which computes the i,j entry of A * B"""
    def __init__(self, A, B, i):
        super(MatrixMult, self).__init__()
        self.A = A
        self.B = B
        self.i = i
        #self.j = j                                                             
    def run(self):
        print "Computing %i, %i" % (self.i, self.i)
        x = 0
        result=[]
        for k in range(self.A.shape[0])
            x += self.A[self.i,k] * self.B[k,self.i  
        self.result=x
        print "Completed %i, %i" % (self.i, self.j)

def mult(n):
    """A function to randomly create two n x n matrices and multiply them"""
    # Create two random matrices                                                
    A = numpy.zeros((n,n))
    B = numpy.zeros((n,n))
    for i in range(n):
        for j in range(n):
            A[i,j] = random.randrange(0, 100)
            B[i,j] = random.randrange(0, 100)
    # Create and start the threads                                              
    threads = []
    for i in range(n):
       # for j in range(n):                                                     
        t = MatrixMult(A, B, i)
            threads.append(t)
            t.start()
    for t in threads: t.join()
    C = numpy.zeros((n,n))
    for t in threads:
        C[t.i] = t.result
    return C
print multi(30)
然而,它打印出许多奇怪的矩阵:

[ 66695.  66695.  66695.  66695.  66695.  66695.  66695.  66695.  66695.
   66695.  66695.  66695.  66695.  66695.  66695.  66695.  66695.  66695.
   66695.  66695.  66695.  66695.  66695.  66695.  66695.  66695.  66695.
   66695.  66695.  66695.]
 [ 88468.  88468.  88468.  88468.  88468.  88468.  88468.  88468.  88468.
   88468.  88468.  88468.  88468.  88468.  88468.  88468.  88468.  88468.
   88468.  88468.  88468.  88468.  88468.  88468.  88468.  88468.  88468.
   88468.  88468.  88468.]]

有人在我的代码中看到了问题吗?我不明白我做错了什么。

1 个答案:

答案 0 :(得分:2)

您的代码集

C[t.i] = t.result

C的整行设置为值t.result,这是一个标量。我在那里看到一些关于j的注释内容;你可能想要考虑到这一点,并且还要修复

x += self.A[self.i,k] * self.B[k,self.i

使用j(也不是语法错误)。原样,您似乎正在计算C[i, i],然后将该值分配给整行。

另外:你知道这段代码保证比np.dot慢得多,对吧?在python中进行紧密循环之间,尽管GIL在线程之间分配计算工作,并且首先是an inefficient algorithm for matrix multiplication。如果您的目标实际上是使用多个核心加速矩阵乘法,将您的numpy链接到MKL,OpenBLAS或ACML,请使用np.dot,并将其称为一天。