我正在挖掘出一块numpy代码,而且我根本就不了解这条代码:
W[:, :, None] * h[None, :, :] * diff[:, None, :]
其中 W , h 和差异是784x20,20x100和784x100矩阵。乘法结果是784x20x100数组,但我不知道这个计算实际上是做什么的,结果是什么意思。
对于它的价值,该线来自机器学习相关代码, W 对应于神经网络层的权重数组, h 是图层激活, diff 是网络目标和假设之间的差异(来自变换自动编码器的Sida Wang's thesis)。
答案 0 :(得分:5)
对于NumPy数组,*
对应于逐元素乘法。为了使其工作,两个数组必须是:
如果在配对每个数组的尾随尺寸时,每个数组中的长度相等或其中一个长度为1,则可以将一个数组广播到另一个数组。
例如,以下数组A
和B
具有与广播兼容的形状:
A.shape == (20, 1, 3)
B.shape == (4, 3)
(3
等于3
,然后A
中的下一个长度为1
,可以与任意长度配对。这并不重要B
的维度少于A
。)
为了使两个不兼容的数组彼此可广播,可以将额外的维度插入到一个或两个数组中。使用None
或np.newaxis
为维度建立索引会将长度为1的额外维度插入数组中。
让我们看一下问题中的例子。 Python评估从左到右的重复乘法:
W[:, :, None]
的形状为(784, 20, 1)
h[None, :, :]
的形状为( 1, 20, 100)
根据上面的解释,这些形状是可播放的,并且乘法返回一个形状为(784, 20, 100)
的数组。
(784, 20, 100)
diff[:, None, :]
的形状为(784, 1, 100)
这两个数组的这些形状是兼容的,因此第二次乘法成功。返回形状为(784, 20, 100)
的数组。