如何加快python执行速度?多重处理无效

时间:2019-08-24 19:39:43

标签: python multiprocessing

我正在尝试进行多处理以加快特征提取过程。 这是我所做的:

import numpy as np
from multiprocessing import Pool, Process
import time


def cal_feature(subject):
    return (np.mean((subject),axis=-1))

#multiprocessing
if __name__ == '__main__':
    st=time.time()
    data_1= np.random.randint(0, 100, size=(300, 100, 2000))
    data_2= np.random.randint(100, 200, size=(300, 100, 2000))
    data_3= np.random.randint(100, 200, size=(300, 100, 2000))
    data_4= np.random.randint(100, 200, size=(300, 100, 2000))
    data_5= np.random.randint(100, 200, size=(300, 100, 2000))
    data={1:data_1,2:data_2,3:data_3,4:data_4,5:data_5}
    p=Pool(10)
    parallel_result=[]
    for i in data.keys():
        result=p.map(cal_feature, np.split(data[i], 10))
        parallel=np.concatenate((result),axis=0)
        parallel_result.append(parallel)
    p.close()
    p.join()
    print('multprocessing total time',time.time()-st)

#Serial processing
st=time.time()
data_1= np.random.randint(0, 100, size=(300, 100, 2000))
data_2= np.random.randint(100, 200, size=(300, 100, 2000))
data_3= np.random.randint(100, 200, size=(300, 100, 2000))
data_4= np.random.randint(100, 200, size=(300, 100, 2000))
data={1:data_1,2:data_2,3:data_3,4:data_4,5:data_5}
series_result=[]
series=[]
for i in data.keys():
    series_result.append(cal_feature(data[i]))

print('series toal time',time.time()-st)

但是,多处理所花的时间比串行编程要长5倍。如何加快特征提取速度?在这里,我使用numpy.mean作为功能,但是在实际数据集中,我有30个复杂的功能。我有80个这样的数据集,而不是随机生成的5个数据集。有没有办法进行健壮的特征提取?

1 个答案:

答案 0 :(得分:0)

重点是每个进程必须为每个 target 处理腌制相当大的数组-引起了轰动。而连续处理直接处理初始数组。有了您的输入数据,我就连续3秒,多处理方法6秒。

您尝试执行的问题:不要在每次迭代中都初始化pool.map-进行一次map调用,因为它已经暗示要处理可迭代的数据项。并且不需要拆分数据项np.split(data[i], 10),而是将data[i]传递给目标函数。

话虽如此,要在您的情况下从multiprocessing方法中获得真正的性能收益-考虑到输入数组的构造方式,我们需要防止对这些大容量数组进行酸洗,并将生成数组的责任转移到 target 函数,只需传入数组 size / shape 选项:

import numpy as np
from multiprocessing import Pool
import time


def cal_feature(start, end, size):
    return np.mean(np.random.randint(start, end, size=size), axis=-1)


# multiprocessing
if __name__ == '__main__':
    st = time.time()
    data_shapes = [(0, 100, (300, 100, 2000)),
                   (100, 200, (300, 100, 2000)),
                   (100, 200, (300, 100, 2000)),
                   (100, 200, (300, 100, 2000)),
                   (100, 200, (300, 100, 2000))]
    with Pool(10) as p:
        result = p.starmap(cal_feature, data_shapes)
        print('multprocessing total time', time.time() - st)

时间执行结果:multprocessing total time 0.6732537746429443(连续处理3秒)

pool.starmap将返回所需的累积结果。