Numpy ndarray乘法切换到矩阵乘法

时间:2016-01-31 18:33:01

标签: python numpy multidimensional-array matrix-multiplication

这是我得到的错误:

File "/data/eduardoj/linear.py", line 305, in _fit_model
        de_dl = (dl_dt + de_dt) * dt_dl
      File "/data/eduardoj/MSc-env/lib/python3.4/site-packages/numpy/matrixlib/defmatrix.py", line 343, in __mul__
        return N.dot(self, asmatrix(other))
    ValueError: shapes (1,53097) and (1,53097) not aligned: 53097 (dim 1) != 1 (dim 0)

这是崩溃的numpy代码:

   340     def __mul__(self, other):                                           
>* 341         if isinstance(other, (N.ndarray, list, tuple)) :                
   342             # This promotes 1-D vectors to row vectors                  
   343             return N.dot(self, asmatrix(other))                         
   344         if isscalar(other) or not hasattr(other, '__rmul__') :          
   345             return N.dot(self, other)                                   
   346         return NotImplemented  

(那里有一个断点>*

在我的脚本中,我有一个包含以下行的循环:

de_dl = (dloss_dt + dr_dt) * dt_dl

de_dldloss_dtdr_dtdt_dl的预期类型和形状为:

ndarray float32 (1, 53097)

所以,我只是想计算元素乘法。我正在使用pudb3来调试我的脚本。我在第一次迭代(i == 0)中检查了它的效果很好(最初生成零)。我注意到,对于第一次迭代,线程 NOT 到达我设置的断点。在下一次迭代(i==1)中,我决定在调用乘法之前停下来,以确保dloss_dtdr_dtdt_dl的类型和形状是还是一样。他们是。

虽然它们是相同的,但似乎程序经历了一系列不同的步骤,并且有些步骤在N.dot乘法中结束。

所以,我要求任何可能阻止我操作一个简单的元素乘法的线索。

1 个答案:

答案 0 :(得分:1)

您在问题中显示的numpy源代码片段与__mul__ method of np.matrixlib.defmatrix.matrix对应。

当你将矩阵对象与另一个矩阵或数组右乘时,会调用此方法,即A * B等于A.__mul__(B)如果A是一个矩阵(它没有'无论B是ndarray还是矩阵。

方法可以可能被调用的唯一方法是(dl_dt + de_dt)是矩阵(或其他一些派生矩阵类)而不是ndarray。因此,dl_dtde_dt正在转换为代码中的某个矩阵。

由于在第一次迭代时未调用__mul__,因此必须在代码中的第305行(de_dl = (dl_dt + de_dt) * dt_dl)之后的某处发生。