我正在尝试进行多处理以加快特征提取过程。 这是我所做的:
import numpy as np
from multiprocessing import Pool, Process
import time
def cal_feature(subject):
return (np.mean((subject),axis=-1))
#multiprocessing
if __name__ == '__main__':
st=time.time()
data_1= np.random.randint(0, 100, size=(300, 100, 2000))
data_2= np.random.randint(100, 200, size=(300, 100, 2000))
data_3= np.random.randint(100, 200, size=(300, 100, 2000))
data_4= np.random.randint(100, 200, size=(300, 100, 2000))
data_5= np.random.randint(100, 200, size=(300, 100, 2000))
data={1:data_1,2:data_2,3:data_3,4:data_4,5:data_5}
p=Pool(10)
parallel_result=[]
for i in data.keys():
result=p.map(cal_feature, np.split(data[i], 10))
parallel=np.concatenate((result),axis=0)
parallel_result.append(parallel)
p.close()
p.join()
print('multprocessing total time',time.time()-st)
#Serial processing
st=time.time()
data_1= np.random.randint(0, 100, size=(300, 100, 2000))
data_2= np.random.randint(100, 200, size=(300, 100, 2000))
data_3= np.random.randint(100, 200, size=(300, 100, 2000))
data_4= np.random.randint(100, 200, size=(300, 100, 2000))
data={1:data_1,2:data_2,3:data_3,4:data_4,5:data_5}
series_result=[]
series=[]
for i in data.keys():
series_result.append(cal_feature(data[i]))
print('series toal time',time.time()-st)
但是,多处理所花的时间比串行编程要长5倍。如何加快特征提取速度?在这里,我使用numpy.mean
作为功能,但是在实际数据集中,我有30个复杂的功能。我有80个这样的数据集,而不是随机生成的5个数据集。有没有办法进行健壮的特征提取?
答案 0 :(得分:0)
重点是每个进程必须为每个 target 处理腌制相当大的数组-引起了轰动。而连续处理直接处理初始数组。有了您的输入数据,我就连续3
秒,多处理方法6
秒。
您尝试执行的问题:不要在每次迭代中都初始化pool.map
-进行一次map
调用,因为它已经暗示要处理可迭代的数据项。并且不需要拆分数据项np.split(data[i], 10)
,而是将data[i]
传递给目标函数。
话虽如此,要在您的情况下从multiprocessing
方法中获得真正的性能收益-考虑到输入数组的构造方式,我们需要防止对这些大容量数组进行酸洗,并将生成数组的责任转移到 target 函数,只需传入数组 size / shape 选项:
import numpy as np
from multiprocessing import Pool
import time
def cal_feature(start, end, size):
return np.mean(np.random.randint(start, end, size=size), axis=-1)
# multiprocessing
if __name__ == '__main__':
st = time.time()
data_shapes = [(0, 100, (300, 100, 2000)),
(100, 200, (300, 100, 2000)),
(100, 200, (300, 100, 2000)),
(100, 200, (300, 100, 2000)),
(100, 200, (300, 100, 2000))]
with Pool(10) as p:
result = p.starmap(cal_feature, data_shapes)
print('multprocessing total time', time.time() - st)
时间执行结果:multprocessing total time 0.6732537746429443
(连续处理3秒)
pool.starmap
将返回所需的累积结果。