numpy中以下操作的含义是什么?

时间:2015-04-04 20:49:49

标签: python numpy matrix machine-learning matrix-multiplication

我正在挖掘出一块numpy代码,而且我根本就不了解这条代码:

W[:, :, None] * h[None, :, :] * diff[:, None, :]

其中 W h 差异是784x20,20x100和784x100矩阵。乘法结果是784x20x100数组,但我不知道这个计算实际上是做什么的,结果是什么意思。

对于它的价值,该线来自机器学习相关代码, W 对应于神经网络层的权重数组, h 是图层激活, diff 是网络目标和假设之间的差异(来自变换自动编码器的Sida Wang's thesis)。

1 个答案:

答案 0 :(得分:5)

对于NumPy数组,*对应于逐元素乘法。为了使其工作,两个数组必须是:

  • 彼此相同的形状
  • 这样一个阵列可以broadcast到另一个

如果在配对每个数组的尾随尺寸时,每个数组中的长度相等或其中一个长度为1,则可以将一个数组广播到另一个数组。

例如,以下数组AB具有与广播兼容的形状:

A.shape == (20, 1, 3)
B.shape ==     (4, 3)

3等于3,然后A中的下一个长度为1,可以与任意长度配对。这并不重要B的维度少于A。)

为了使两个不兼容的数组彼此可广播,可以将额外的维度插入到一个或两个数组中。使用Nonenp.newaxis为维度建立索引会将长度为1的额外维度插​​入数组中。


让我们看一下问题中的例子。 Python评估从左到右的重复乘法:

  • W[:, :, None]的形状为(784, 20, 1)
  • h[None, :, :]的形状为( 1, 20, 100)

根据上面的解释,这些形状是可播放的,并且乘法返回一个形状为(784, 20, 100)的数组。

  • 上次乘法的数组形状(784, 20, 100)
  • diff[:, None, :]的形状为(784, 1, 100)

这两个数组的这些形状是兼容的,因此第二次乘法成功。返回形状为(784, 20, 100)的数组。