此numpy操作会出现内存错误。 (这里X和Y是形状分别为(5000,3072)和(500,3072)的二维数组)
dists[:,:] = np.sqrt(np.sum(np.square(np.subtract(X, Y[:,np.newaxis])), axis=2))
我认为Numpy数组广播占用了大量内存。有什么方法可以优化这些阵列操作的内存使用量?
编辑:
(这是cs231n的分配1中的问题)。我找到了另一种解决方案,可以提供相同的结果而不会出现内存错误:
dists[:,:] = np.sqrt((Y**2).sum(axis=1)[:, np.newaxis] + (X**2).sum(axis=1) - 2 * Y.dot(X.T))
您能帮助我理解为什么我的解决方案在内存方面效率低下吗?