使用切片的矩阵乘法。蟒蛇

时间:2015-03-04 21:37:35

标签: python slice matrix-multiplication

我有以下代码:

from numpy import *
a = random.rand(3,4)
b = random.rand(4,2)
c = linspace(0,0,6)
c.shape = (3,2)
for i in range(a.shape[0]): 
  for j in range(b.shape[1]): 
      for k in range(b.shape[0]): 
         c[i][j] += a[i][k] * b[k][j]
for r in c: 
  print "C = ", r

但是我需要改变最后一个(最里面的)循环,我需要使用切片。据我所知,我必须做这样的事情:

for i in range(a.shape[0]): 
 for j in range(b.shape[1]): 
     c[i][j] += a[i][0:l-1] * b[0:l-1][j]

但不幸的是,它并没有奏效。有人可以帮助我,并提示如何做到这一点吗?

2 个答案:

答案 0 :(得分:1)

您在此处尝试执行的操作是来自a的行向量和来自b的列向量的点积:

c[i][j] += a[i][0:l-1] * b[0:l-1][j]

将是

c[i][j] = np.dot(a[i], b[:][j]) 

相同
sum([a_*b_ for a_,b_ in zip(a[i],b[:][j])])

sum(a[i]*b[:][j])

但速度更快。

但是,如果您正在使用np.dot,则无论如何:

c = np.dot(a,b)

肯定更快。

答案 1 :(得分:0)

让我们从一个帮助函数开始,该函数创建一个r行的ListOfLists(一个lol),每行包含c列:

In [1]: def lol(r,c): return [[i*c+j for j in range(c)] for i in range(r)]

并创建两个列表列表

In [2]: a = lol(2,5) ; b = lol(5,4)

我们要验证下面的代码,意味着使用两个lol的矩阵产品是否正常工作,因此我们从ndarraya创建两个b并形成他们的代码内在产品

In [3]: from numpy import array

In [4]: aa = array(a) ; ab = array(b) ; aa.dot(ab)
Out[4]: 
array([[120, 130, 140, 150],
       [320, 355, 390, 425]])

现在,我们可以测试两个lols的内部或矩阵产品的代码

In [5]: [[sum(x*y for x, y in zip(ar,bc)) for bc in zip(*b)] for ar in a]
Out[5]: [[120, 130, 140, 150], [320, 355, 390, 425]]

没关系,不是吗? (我不得不说在代码的第一次迭代中我得出了结果的转置......)。

现在我们有了一点信心,让我们尝试一些更实质的东西

In [6]: a = lol(200,50) ; b = lol(50,400)

In [7]: aa = array(a) ; ab = array(b)

In [8]: %timeit c = aa.dot(ab)
100 loops, best of 3: 4.53 ms per loop

In [9]: %timeit c = [[sum(x*y for x, y in zip(ar,bc)) for bc in zip(*b)] for ar in a]
1 loops, best of 3: 469 ms per loop

正如您所看到的,numpy比列表上的操作快两个数量级,但在OP问题的上下文中,在ndarray上尝试我们的列表代码更有趣S:

In [10]: %timeit c = [[sum(x*y for x, y in zip(ar,bc)) for bc in zip(*ab)] for ar in aa] 
1 loops, best of 3: 1.32 s per loop

哦,如果你有numpy数组,那么使用数组方法而不是对单个元素进行操作会更好......但等等,我们有一个更快的替代内部zip:< / p>

In [11]: %timeit c = [[sum(x*y for x, y in zip(ar,bc)) for bc in ab.T] for ar in aa]
1 loops, best of 3: 1.34 s per loop

In [12]: 

不,即使我们使用ndarray的转置属性,我们也会得到相同的结果。

总结:永远不要使用单独访问的numpy数组元素来进行繁重的计算提升......


感谢ipython及其%timeit 魔法,这让我更轻松有趣(对我而言)。