I have several .npz files. All of them share the same structure: each file contains exactly two variables, always stored under the same variable names. At the moment I simply loop over all the .npz files, read the two variables and append them to some global lists:
import numpy as np

# Let's assume there are 100 npz files
x_train = []
y_train = []
for npz_file_number in range(100):
    data = dict(np.load('{0:04d}.npz'.format(npz_file_number)))
    x_train.append(data['x'])
    y_train.append(data['y'])
This takes a while, and the bottleneck is the CPU. The order in which the x and y values are appended to x_train and y_train does not matter.
Is there a way to load several .npz files in multiple threads?
Answer 0 (score: 2):
I was surprised by @Brent Washburne's comment and decided to try it out myself. I think the general problem is twofold:
Firstly, reading data is often IO-bound, so writing multi-threaded code often does not yield much of a performance gain. Secondly, shared-memory parallelisation in Python is inherently difficult due to the design of the language itself (the global interpreter lock); there is much more overhead compared to native C.
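That said, loading compressed .npz files spends much of its time in file IO and zlib decompression, which largely release the GIL, so plain threads can still be worth trying; here is a minimal sketch using concurrent.futures, with the helper names chosen only for illustration:

from concurrent.futures import ThreadPoolExecutor
import glob
import os

import numpy as np

def read_x_threaded(path):
    # indexing the NpzFile loads the 'x' array into memory before the file is closed
    with np.load(path) as data:
        return data['x']

def threaded_read(files, max_workers=4):
    # threads share memory, so the arrays do not have to be pickled between processes
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(read_x_threaded, files))

# files = glob.glob(os.path.join('tmp', 'nptest', '*.npz'))
# x_list = threaded_read(files)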
But let's see what we can do.
# some imports
import numpy as np
import glob
from multiprocessing import Pool
import os

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)

for i in range(100):
    x = np.random.rand(10000, 50)
    file_path = os.path.join(tmp_dir, '%05d.npz' % i)
    np.savez_compressed(file_path, x=x)

def read_x(path):
    with np.load(path) as data:
        return data["x"]

def serial_read(files):
    x_list = list(map(read_x, files))
    return x_list

def parallel_read(files):
    with Pool() as pool:
        x_list = pool.map(read_x, files)
    return x_list
All right, that's the setup done. Let's look at the timings.
files = glob.glob(os.path.join(tmp_dir, '*.npz'))
%timeit x_serial = serial_read(files)
# 1 loops, best of 3: 7.04 s per loop
%timeit x_parallel = parallel_read(files)
# 1 loops, best of 3: 3.56 s per loop
np.allclose(x_serial, x_parallel)
# True
That actually looks like a decent speedup. I am running this on two physical cores and two hyper-threading cores.
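If you eventually want a single array rather than a list of per-file arrays (as with x_train in the question), the returned list can simply be stacked afterwards; a small sketch, assuming every file holds an array of the same shape:

# combine the per-file arrays into one big array
x_all = np.concatenate(x_parallel, axis=0)  # shape: (100 * 10000, 50)
# or keep the per-file axis
x_stacked = np.stack(x_parallel)            # shape: (100, 10000, 50)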
To run and time everything in one go, you can execute the following script:
from __future__ import print_function
from __future__ import division

# some imports
import numpy as np
import glob
import sys
import multiprocessing
import os
import timeit

# creating some temporary data
tmp_dir = os.path.join('tmp', 'nptest')
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)

for i in range(100):
    x = np.random.rand(10000, 50)
    file_path = os.path.join(tmp_dir, '%05d.npz' % i)
    np.savez_compressed(file_path, x=x)

def read_x(path):
    data = dict(np.load(path))
    return data['x']

def serial_read(files):
    x_list = list(map(read_x, files))
    return x_list

def parallel_read(files):
    pool = multiprocessing.Pool(processes=4)
    x_list = pool.map(read_x, files)
    return x_list

files = glob.glob(os.path.join(tmp_dir, '*.npz'))
#files = files[0:5] # to test on a subset of the npz files

# Timing:
timeit_runs = 5

timer = timeit.Timer(lambda: serial_read(files))
print('serial_read: {0:.4f} seconds averaged over {1} runs'
      .format(timer.timeit(number=timeit_runs) / timeit_runs,
              timeit_runs))
# 1 loops, best of 3: 7.04 s per loop

timer = timeit.Timer(lambda: parallel_read(files))
print('parallel_read: {0:.4f} seconds averaged over {1} runs'
      .format(timer.timeit(number=timeit_runs) / timeit_runs,
              timeit_runs))
# 1 loops, best of 3: 3.56 s per loop
# Examples of use:
x = serial_read(files)
print('len(x): {0}'.format(len(x)))              # len(x): 100
print('len(x[0]): {0}'.format(len(x[0])))        # len(x[0]): 10000
print('len(x[0][0]): {0}'.format(len(x[0][0])))  # len(x[0][0]): 50
print('x[0][0]: {0}'.format(x[0][0]))            # x[0][0]: first row of the first file (50 values)
print('x[0].nbytes: {0} MB'.format(x[0].nbytes / 1e6))  # x[0].nbytes: 4.0 MB
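To map this back to the original question, where each .npz file contains both an 'x' and a 'y' array, the reader only needs to return both values; a minimal sketch (the helper names are just illustrative):

def read_xy(path):
    # load both arrays from one .npz file
    with np.load(path) as data:
        return data['x'], data['y']

def parallel_read_xy(files):
    pool = multiprocessing.Pool(processes=4)
    pairs = pool.map(read_xy, files)            # list of (x, y) tuples
    x_train, y_train = map(list, zip(*pairs))   # split into two lists
    return x_train, y_train

# x_train, y_train = parallel_read_xy(files)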