我收到了MemoryError
numpy.where
,但我不确定原因。我不能在这里发布实际代码,但下面是一个复制问题的小工作示例。
import numpy as np
dat = np.random.randn(100000, 1, 1, 1, 45, 2, 3)
# The following two steps seem superfluous but I wanted to replicate
# behaviour in the original code
cond = dat[:,0,0,0,0,0,0] > 0
cond = cond[:,None,None,None,None,None,None]
dat2 = np.where(cond, dat, 0)
dat[...,2] = np.where(cond, dat[...,2], dat2[...,2]) # Causes MemoryError
我知道为我的计算机添加更多内存可以解决问题,但我想了解这里发生了什么。
我希望上面的数组切片不会复制数组但只返回一个视图,但我认为它实际上是出于某种原因复制数组。
答案 0 :(得分:1)
此处没有“魔法”,您使用np.random.randn(100000, 1, 1, 1, 45, 2, 3)
创建的数据阵列非常大。
Numpy似乎将每个数字存储为64位(8字节)浮点数,因此您的数组占用大约206兆字节的内存(100000 * 1 * 1 * 1 * 45 * 2 * 3 * 8)。
/usr/bin/time -v python test.py
表示该程序在其峰值处使用大约580 MB,这可能是由于复制了该对象。