python多处理模块

时间:2017-05-25 20:38:27

标签: python multithreading multiprocessing python-multiprocessing

我正在使用机器学习算法(SVM)对卫星图像进行分类。我使用的图像是7GB。所以,我需要使用multiprocessing python模块来加快计算时间。我已阅读Stack Overflow中的所有帖子以及multiprocessing模块的文档,了解如何使用它。我发现使用pool方法后我的代码变得很慢。显然,我做错了什么,我不知道它是什么。下面,我使用了multiprocessing

img = '/home/SVM_seaIce_types/subset_of_subset_calibration_201605_Polar_stereographic_ratio.tif' 

def predict():

new_shape = (img.shape[0] * img.shape[1], img.shape[2]-2 )

#reshape the image from 3d to 2d in order to use it for classification
img_as_array = img[:, :, 2:].reshape(new_shape)
print('Reshaped from {o} to {n}'.format(o=img.shape,
                                        n=img_as_array.shape))

# Now predict for each pixel
class_prediction = svm.predict(img_as_array)

# Reshape the image and produce the classification map
class_prediction = class_prediction.reshape(img[:, :, 0].shape)
return class_prediction

if __name__ == '__main__':
    start = time.time()

    pool = Pool(processes=5)
    result = pool.apply_async(predict) 
    print result.get()

    end = time.time()
    print 'the processing time is',(end - start)

enter image description here

以下是不使用多处理的代码

img = '/home/SVM_seaIce_types/subset_of_subset_calibration_201605_Polar_stereographic_ratio.tif' 
start = time.time()

new_shape = (img.shape[0] * img.shape[1], img.shape[2]-2 )

img_as_array = img[:, :, 2:].reshape(new_shape)
print('Reshaped from {o} to {n}'.format(o=img.shape,
                                    n=img_as_array.shape))

# Now predict for each pixel
class_prediction = svm.predict(img_as_array)

# Reshape our classification map
class_prediction = class_prediction.reshape(img[:, :, 0].shape)

print class_prediction

end = time.time()
print 'the processing time is:', end - start

enter image description here

通过查看处理时间,我们看到了很大的差异。我不明白发生了什么。我可能还不太了解multiprocessing模块是如何工作的。这就是我需要你帮助的原因。

顺便说一下,在我的模型预测之后,填充数字1的数组对应于第1类。

在收到一些有用的反馈后,我正在对我的帖子进行一些编辑:

感谢大家对我的问题的回复。我遵循你所做的表扬并相应地更改了我的代码。现在,我可以在正确使用多处理模块后看到我的代码运行得更快。现在的问题是我没有得到我期待的结果。这是我的代码

img = '/home/john/desktop/seaIce.tif' #shape of image: 500 x 500 x 5    (row,cols,bands)

#reshape image 3d(rows x columns x number of bands) into 2d (total size  x    number of bands)
image = (img.shape[0] * img.shape[1], img.shape[2]-2 )
tfs2d = img[:, :, 2:].reshape(image)

#find data without nan values 
gpi = np.isfinite(tfs2d.sum(axis=0))
tfsgood = tfs2d[gpi, :]

#use svm model to clasify our image
def predict_class(input_data):
    prediction = svm.predict(input_data)
    return prediction


def main():
    #slice the image into chunks
    chunk_size=100    
    chunks = [tfsgood[i:i+chunk_size, :]
      for i in xrange(0, tfsgood.shape[0], chunk_size)]

    #use multiprocessing module
    pool = Pool(6)
    svm_labelsgood = pool.map(predict_class, chunks)

    # join the results
    svm_labelsgood = np.dstack(svm_labelsgood)
    svmlabelsall = np.zeros(tfs2d.shape[0])
    svmlabelsall[gpi] = svm_labelsgood

    #reshape the image so we can display it with matplotlib
    reshape = svmlabelsall.reshape(img.shape[0], img.shape[1])
    print img.shape
    print reshape.shape


    plt.imshow(reshape)
    plt.show()


if __name__ == '__main__':
    main()

我得到的分类图片是:

enter image description here

而不是得到这个:

enter image description here

底部的图像是由3个类别(海冰类型)组成的最终分类图像。问题是我不能等待10个小时才能得到分类结果。应用multiprocessing模块时我做错了什么?我盯着我的代码几个小时,无法理解为什么我会得到一个蓝色图像。

0 个答案:

没有答案