从文件中读取numpy数组并解析非常慢

时间:2017-11-21 09:40:31

标签: python numpy numpy-memmap fromfile

我有一个二进制文件,我将其解析为Python中的numpy数组,如下所示:

bytestream= np.fromfile(path, dtype=np.int16)

 for a in range(sizeA):
        for x in range(0, sizeX):
            for y in range(0, sizeY):
                for z in range(0, sizeZ):
                    parsed[a, x, y, z] = bytestream[z + (sizeZ * x) + (sizeZ * sizeX * y) + (sizeZ * sizeX * sizeY * a)]

但是,这非常慢。谁能告诉我为什么以及如何加快它?

2 个答案:

答案 0 :(得分:1)

你似乎在你的代码中犯了一个错误,我相信假设行主要排序,x和y应该在(sizeZ * x) + (sizeZ * sizeX * y)中反转。在任何情况下,请检查下面的代码,该代码验证重塑是您想要的。它之所以慢的原因是嵌套的for循环。

在python中,for循环是一个非常复杂的构造,具有非常大的开销。因此,在大多数情况下,您应该避免使用循环并使用库提供的函数(它们也有for循环但在c / c ++中完成)。你会发现"删除for循环"在numpy中是一个常见的问题,因为大多数人会首先尝试一些他们以最直接的方式知道的算法(例如卷积,最大池)。并且意识到它非常缓慢并且基于numpy api寻找聪明的替代方案,其中大部分计算转移到c ++端而不是在python中发生。

import numpy as np

# gen some data 
arr= (np.random.random((4,4,4,4))*10).astype(np.int16)
arr.tofile('test.bin')

# original code
bytestream=np.fromfile('test.bin',dtype=np.int16)
parsed=np.zeros(arr.shape,dtype=np.int16)
sizeA,sizeX,sizeY,sizeZ=arr.shape
for a in range(sizeA):
    for x in range(0, sizeX):
        for y in range(0, sizeY):
            for z in range(0, sizeZ):
                parsed[a, x, y, z] = bytestream[z + (sizeZ * y) + (sizeZ * sizeX * x) + (sizeZ * sizeX * sizeY * a)]

print(np.allclose(arr,parsed))
print(np.allclose(arr,bytestream.reshape((sizeA,sizeX,sizeY,sizeZ))))

答案 1 :(得分:0)

你正在将一个单元格的numpy数组parsed更新一个单元格,不得不在python和每个单元格的numpy的C实现之间反弹。这是一个严重的开销。 (更不用说在zaw lin所说的每个python迭代中必须更新python变量ayxz的成本,以及计算成本的成本指数)

当您执行一些numpy C代码时,使用numpy.copynumpy.reshapenumpy.moveaxis来尽可能多地更新尽可能多的值。