I have a data-generation class that produces data in batches. A simplified version is below:
import numpy as np
import os
import psutil

def memory_check():
    pid = os.getpid()
    py_mem = psutil.Process(pid)
    memory_use = py_mem.memory_info()[0] / 2. ** 30  # resident set size in GiB
    return {"python_usage": memory_use}

class DataBatcher:
    def __init__(self, X, batch_size):
        self.X = X
        self.start = 0
        self.batch_size = batch_size
        self.row_dim, col_dim = X.shape
        self.batch = np.zeros((batch_size, col_dim))  # preallocated output buffer

    def gen_batch(self):
        end_index = self.start + self.batch_size
        if end_index < self.row_dim:
            indices = range(self.start, end_index)
            print("before assign batch \n", memory_check())
            self.batch[:] = self.X.take(indices, axis=0, mode='wrap')
            print("after assign batch \n", memory_check())
            self.start = end_index
        return self.batch

if __name__ == "__main__":
    X = np.random.sample((1000000, 50))
    for i in range(100):
        data_batcher = DataBatcher(X, 5000)
        x = data_batcher.gen_batch()
The actual code is very close to the above, except that self.X is generated in another method inside the DataBatcher class and is updated periodically. I noticed that, whenever any change is made to self.X, Python's memory usage grows steadily on every round at the line self.batch[:] = self.X.take(indices, axis=0, mode='wrap'). I don't think this should happen, since I have already preallocated the memory for self.batch?
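For reference, here is a minimal sketch (not my actual code) of how the assignment could be done without the intermediate array that take returns, by passing take's out= argument so it writes straight into the preallocated buffer:

    # Sketch only: same gen_batch body, but take writes directly into the
    # preallocated self.batch instead of returning a temporary that is then copied.
    def gen_batch(self):
        end_index = self.start + self.batch_size
        if end_index < self.row_dim:
            indices = range(self.start, end_index)
            print("before assign batch \n", memory_check())
            self.X.take(indices, axis=0, mode='wrap', out=self.batch)
            print("after assign batch \n", memory_check())
            self.start = end_index
        return self.batch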
Answer 0 (score: 0)
As answered in Why does numpy.zeros takes up little space, this surprising behavior is likely an operating-system-level optimization: np.zeros does not actually claim the memory until you effectively write over it with self.batch[:] = self.X.take(indices, axis=0, mode='wrap').
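A small sketch that makes this visible (it reuses the memory_check helper from the question, and assumes an OS that commits zeroed pages lazily, e.g. Linux): RSS barely moves when the zero array is created and only jumps once the pages are first written.

    buf = np.zeros((1000000, 50))       # ~381 MiB requested, but pages not yet committed
    print("after np.zeros", memory_check())
    buf[:] = 1.0                        # first write touches every page, so RSS jumps
    print("after first write", memory_check())
    buf[:] = 2.0                        # later writes reuse the same pages, RSS stays flat
    print("after second write", memory_check())

In the loop in __main__ a new DataBatcher (and thus a fresh np.zeros buffer) is created on every iteration, so each round's first write into self.batch commits new pages, which is consistent with the steady growth observed at that line.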