为什么np.genfromtxt()最初会占用大量数据集的大量内存?

时间:2018-03-02 19:53:11

标签: python numpy

我有一个包含450,000列和450行的数据集 - 所有数值。我将数据集加载到具有np.genfromtxt()函数的NumPy数组中:

# The skip_header skips over the column names, which is the first row in the file
train = np.genfromtxt('train_data.csv', delimiter=',', skip_header=1)

train_labels = train[:, -1].astype(int)
train_features = train[:, :-1]

当函数最初加载数据集时,它使用超过15-20 GB的RAM。但是,在函数完成运行后,它将降低到仅使用2-3 GB的RAM。为什么np.genfromtxt()最初耗尽了这么多内存?

2 个答案:

答案 0 :(得分:0)

如果您提前知道数组的大小,可以通过在解析目标数组时将每一行加载到目标数组中来节省时间和空间。

例如:

In [173]: txt="""1,2,3,4,5,6,7,8,9,10
     ...: 2,3,4,5,6,7,8,9,10,11
     ...: 3,4,5,6,7,8,9,10,11,12
     ...: """

In [174]: np.genfromtxt(txt.splitlines(),dtype=int,delimiter=',',encoding=None)
Out[174]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])

使用更简单的解析功能:

In [177]: def foo(txt,size):
     ...:     out = np.empty(size, int)
     ...:     for i,line in enumerate(txt):
     ...:        out[i,:] = line.split(',')
     ...:     return out
     ...: 
In [178]: foo(txt.splitlines(),(3,10))
Out[178]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])

out[i,:] = line.split(',')将字符串列表加载到数字dtype数组中会强制进行转换,与np.array(line..., dtype=int)相同。

In [179]: timeit np.genfromtxt(txt.splitlines(),dtype=int,delimiter=',',encoding
     ...: =None)
266 µs ± 427 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [180]: timeit foo(txt.splitlines(),(3,10))
19.2 µs ± 169 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

更简单,直接的解析器要快得多。

但是,如果我尝试使用loadtxtgenfromtxt的简化版本:

In [184]: def bar(txt):
     ...:     alist=[]
     ...:     for i,line in enumerate(txt):
     ...:        alist.append(line.split(','))
     ...:     return np.array(alist, dtype=int)
     ...: 
     ...: 
In [185]: bar(txt.splitlines())
Out[185]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [186]: timeit bar(txt.splitlines())
13 µs ± 20.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

对于这个小案例,它甚至更快。 genfromtxt必须有很多解析开销。这是一个小样本,因此内存消耗并不重要。

为了完整性,loadtxt

In [187]: np.loadtxt(txt.splitlines(),dtype=int,delimiter=',')
Out[187]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [188]: timeit np.loadtxt(txt.splitlines(),dtype=int,delimiter=',')
103 µs ± 50.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

fromiter

In [206]: def g(txt):
     ...:     for row in txt:
     ...:         for item in row.split(','):
     ...:             yield item
In [209]: np.fromiter(g(txt.splitlines()),dtype=int).reshape(3,10)
Out[209]: 
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
       [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12]])
In [210]: timeit np.fromiter(g(txt.splitlines()),dtype=int).reshape(3,10)
12.3 µs ± 21.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

答案 1 :(得分:0)

@Kasramvd在评论中提出了一个很好的建议,以研究提出的解决方案here。该答案的iter_loadtxt()解决方案证明是我问题的完美解决方案:

def iter_loadtxt(filename, delimiter=',', skiprows=0, dtype=float):
    def iter_func():
        with open(filename, 'r') as infile:
            for _ in range(skiprows):
                next(infile)
            for line in infile:
                line = line.rstrip().split(delimiter)
                for item in line:
                    yield dtype(item)
        iter_loadtxt.rowlength = len(line)

    data = np.fromiter(iter_func(), dtype=dtype)
    data = data.reshape((-1, iter_loadtxt.rowlength))
    return data

genfromtxt()占用如此多内存的原因是因为它在解析数据文件时没有将数据存储在高效的NumPy数组中,因此在NumPy解析我的大数据文件时会占用过多的内存。 / p>