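I am using a machine learning algorithm (SVM) to classify a satellite image. The image is 7 GB, so I need the Python multiprocessing module to cut the computation time. I have read all the Stack Overflow posts on the topic as well as the multiprocessing module documentation to learn how to use it. What I found is that my code became much slower once I used the Pool approach. Clearly I am doing something wrong, but I do not know what. Below is the version that uses multiprocessing: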
import time
from multiprocessing import Pool

# img is assumed to have been loaded from this path as a (rows, cols, bands) array,
# and svm to be an already trained classifier; neither step is shown here.
img = '/home/SVM_seaIce_types/subset_of_subset_calibration_201605_Polar_stereographic_ratio.tif'

def predict():
    new_shape = (img.shape[0] * img.shape[1], img.shape[2] - 2)
    # reshape the image from 3D to 2D in order to use it for classification
    img_as_array = img[:, :, 2:].reshape(new_shape)
    print('Reshaped from {o} to {n}'.format(o=img.shape, n=img_as_array.shape))
    # now predict for each pixel
    class_prediction = svm.predict(img_as_array)
    # reshape the predictions to produce the classification map
    class_prediction = class_prediction.reshape(img[:, :, 0].shape)
    return class_prediction

if __name__ == '__main__':
    start = time.time()
    pool = Pool(processes=5)
    # a single task is submitted to the pool
    result = pool.apply_async(predict)
    print(result.get())
    end = time.time()
    print('the processing time is {}'.format(end - start))
Here is the code without multiprocessing:
import time

# img is assumed to have been loaded from this path as a (rows, cols, bands) array,
# and svm to be an already trained classifier; neither step is shown here.
img = '/home/SVM_seaIce_types/subset_of_subset_calibration_201605_Polar_stereographic_ratio.tif'

start = time.time()
new_shape = (img.shape[0] * img.shape[1], img.shape[2] - 2)
img_as_array = img[:, :, 2:].reshape(new_shape)
print('Reshaped from {o} to {n}'.format(o=img.shape, n=img_as_array.shape))
# now predict for each pixel
class_prediction = svm.predict(img_as_array)
# reshape our classification map
class_prediction = class_prediction.reshape(img[:, :, 0].shape)
print(class_prediction)
end = time.time()
print('the processing time is: {}'.format(end - start))
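Comparing the processing times shows a big difference between the two versions, and I do not understand what is happening. I probably do not yet understand how the multiprocessing module works, which is why I need your help.
By the way, after my model's prediction, array entries filled with the number 1 correspond to class 1.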
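To make the comparison concrete, here is a toy sketch (Python 3, standard library only) of the difference between submitting one task with `apply_async` and submitting one task per chunk with `map`. The names `square_chunk`, `split`, `run_single_task`, and `run_chunked` are hypothetical, and `square_chunk` is only a stand-in for `svm.predict`; it is an illustration of how the pool hands out work, not a measurement of my real code.

```python
import os
from multiprocessing import Pool


def square_chunk(chunk):
    """Toy stand-in for svm.predict: return the worker PID and the squared values."""
    return os.getpid(), [x * x for x in chunk]


def split(items, chunk_size):
    """Split a list into consecutive chunks of at most chunk_size items."""
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]


def run_single_task(pool, items):
    # One apply_async call is one task, so a single worker does all the work.
    pid, values = pool.apply_async(square_chunk, (items,)).get()
    return {pid}, values


def run_chunked(pool, items, chunk_size):
    # map submits one task per chunk, so several workers can share the load.
    results = pool.map(square_chunk, split(items, chunk_size))
    pids = {pid for pid, _ in results}
    values = [v for _, chunk_values in results for v in chunk_values]
    return pids, values


if __name__ == "__main__":
    data = list(range(20))
    with Pool(processes=4) as pool:
        single_pids, single_values = run_single_task(pool, data)
        chunk_pids, chunk_values = run_chunked(pool, data, chunk_size=5)
    print("apply_async used", len(single_pids), "worker process")
    print("map used", len(chunk_pids), "worker process(es)")
    print("same result:", single_values == chunk_values)
```

With a single `apply_async` call the other workers stay idle, so the pool can only add start-up and data-transfer overhead on top of the serial run. How many distinct workers `map` ends up using depends on scheduling, so the second count may vary from run to run.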
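After receiving some useful feedback, I am editing my post:
Thank you all for the replies to my question. I followed your recommendations and changed my code accordingly. Now that I use the multiprocessing module correctly, my code runs faster. The problem now is that I do not get the result I expect. Here is my code: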
import numpy as np
import matplotlib.pyplot as plt
from multiprocessing import Pool

# img is assumed to have been loaded from this path as an array; svm is the trained model
img = '/home/john/desktop/seaIce.tif'  # shape of image: 500 x 500 x 5 (rows, cols, bands)
# reshape image from 3D (rows x columns x number of bands) into 2D (total size x number of bands)
image = (img.shape[0] * img.shape[1], img.shape[2] - 2)
tfs2d = img[:, :, 2:].reshape(image)
# find data without nan values
gpi = np.isfinite(tfs2d.sum(axis=0))
tfsgood = tfs2d[gpi, :]

# use svm model to classify our image
def predict_class(input_data):
    prediction = svm.predict(input_data)
    return prediction

def main():
    # slice the image into chunks
    chunk_size = 100
    chunks = [tfsgood[i:i + chunk_size, :]
              for i in range(0, tfsgood.shape[0], chunk_size)]
    # use multiprocessing module
    pool = Pool(6)
    svm_labelsgood = pool.map(predict_class, chunks)
    # join the results
    svm_labelsgood = np.dstack(svm_labelsgood)
    svmlabelsall = np.zeros(tfs2d.shape[0])
    svmlabelsall[gpi] = svm_labelsgood
    # reshape the image so we can display it with matplotlib
    reshape = svmlabelsall.reshape(img.shape[0], img.shape[1])
    print(img.shape)
    print(reshape.shape)
    plt.imshow(reshape)
    plt.show()

if __name__ == '__main__':
    main()
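The classified image I get is:
instead of this:
The bottom image is the final classified image, made up of 3 classes (sea ice types). The problem is that I cannot wait 10 hours for the classification result. What am I doing wrong when applying the multiprocessing module? I have stared at my code for hours and cannot understand why I get a blue image.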
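For comparison, here is a minimal sketch of the mask, chunk, and rejoin steps that my code is attempting, run serially so that the bookkeeping can be checked without a pool. It rests on assumptions that I have not confirmed against my real data: that the finite-value mask should hold one flag per pixel (`axis=1`), and that the per-chunk labels are 1-D arrays that can be joined end to end with `np.concatenate`. `fake_predict` and `classify_in_chunks` are hypothetical names, and `fake_predict` only stands in for `svm.predict`.

```python
import numpy as np


def fake_predict(pixels):
    """Hypothetical stand-in for svm.predict: label 1, 2 or 3 from the first used band."""
    return np.digitize(pixels[:, 0], bins=[0.33, 0.66]) + 1


def classify_in_chunks(img, predict, chunk_size=100, first_band=2):
    rows, cols, bands = img.shape
    # Flatten (rows, cols, bands) to (rows * cols, used_bands).
    pixels = img[:, :, first_band:].reshape(rows * cols, bands - first_band)
    # One boolean per pixel: True where every used band is finite.
    good = np.isfinite(pixels.sum(axis=1))
    good_pixels = pixels[good]
    chunks = [good_pixels[i:i + chunk_size]
              for i in range(0, good_pixels.shape[0], chunk_size)]
    # Serial here; pool.map(predict, chunks) returns results in the same order.
    labels = [predict(chunk) for chunk in chunks]
    # Pixels that contain NaN keep the label 0.
    all_labels = np.zeros(rows * cols)
    if labels:
        # Join the 1-D label arrays end to end and put them back in place.
        all_labels[good] = np.concatenate(labels)
    return all_labels.reshape(rows, cols)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    img = rng.random((50, 40, 5))
    img[0, 0, 3] = np.nan  # one bad pixel
    class_map = classify_in_chunks(img, fake_predict, chunk_size=100)
    print(class_map.shape, np.unique(class_map))
```

If the serial version produces the expected map, replacing the list comprehension with `pool.map(predict, chunks)` should leave the result unchanged, because `map` preserves the order of the chunks.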