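Can I get some help parallelizing this code? I am converting a multi-label classification problem into a OneVsRest (binary relevance) problem. Because of the memory issues mentioned here, I am doing it manually.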
clf_label = {}
for i, label in enumerate(label_index.keys()):
    print 'Fitting', i, 'label out of', len(label_index)
    clf = SGDClassifier(loss='hinge', shuffle=True, alpha=0.000001, verbose=0, n_iter=5, n_jobs=4)
    temp_y = np.zeros(trainY.shape)
    temp_y[label_index[label]] = 1
    clf.fit(trainX, temp_y)
    clf_label[label] = clf
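I loop over the keys of label_index and build a classifier for each label. After each classifier is fitted, I save it in another dict whose keys are again the labels and whose values are the classifiers. Because of the long run time, I want to parallelize this code. Here is my attempt with multiprocessing's Pool.map: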
def fit_label(label, trainX, trainY, label_index):
    # print 'Fitting', i, 'label out of', len(label_index)
    clf = SGDClassifier(loss='hinge', shuffle=True, alpha=0.000001, verbose=0, n_iter=5)
    temp_y = np.zeros(trainY.shape)
    temp_y[label_index[label]] = 1
    clf.fit(trainX, temp_y)
    return clf

def linear_svm():
    p = Pool(2)
    func = partial(fit_label, trainX=trainX, trainY=trainY, label_index=label_index)
    res = p.map(func, label_index.keys()[1:6])
    clf_label = dict(zip(label_index.keys()[1:6], res))
I get this error:
Exception in thread Thread-3:
Traceback (most recent call last):
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/threading.py", line 808, in __bootstrap_inner
    self.run()
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/threading.py", line 761, in run
    self.__target(*self.__args, **self.__kwargs)
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/multiprocessing/pool.py", line 342, in _handle_tasks
    put(task)
SystemError: NULL result without error in PyObject_Call
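This seems like a fairly easy task for someone who knows parallel programming in Python, so I would really appreciate it if someone could rewrite this in parallel rather than modify my (dodgy) code. Thanks.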
Answer 0 (score: 1)
Try defining the function to be parallelized outside of the function linear_svm(), like this:
import multiprocessing
from multiprocessing import Pool

def func(label):
    # Module-level wrapper, so the pool can pickle it by name.
    # Relies on fit_label, trainX, trainY and label_index being module-level names.
    return fit_label(label, trainX=trainX, trainY=trainY, label_index=label_index)

def linear_svm():
    numProcessors = multiprocessing.cpu_count()
    p = Pool(processes=numProcessors)
    labels = label_index.keys()[1:6]
    res = p.map_async(func, labels)
    poolres = res.get()
    clf_label = dict(zip(labels, poolres))