如何使用外部模块

时间:2019-01-25 16:04:11

标签: scikit-learn joblib

我正在尝试修改sklearn源代码。特别是,我正在修改GridSearch源代码,以使评估不同模型配置的单独进程/线程在它们之间共享一个变量。我需要每个线程/进程在运行时读取/更新该变量,以便根据获得的其他线程来修改其执行。更具体而言,我要共享的参数是以下示例中的最佳

out = parallel(delayed(_fit_and_score)(clone(base_estimator), X, y, best, self.early,train=train, test=test,parameters=parameters,**fit_and_score_kwargs) for parameters, (train, test) in product(candidate_params, cv.split(X, y, groups))) 

注意 _fit_and_score 函数位于单独的模块中。 Sklearn利用joblib进行并行化,但是我无法理解如何使用外部模块有效地做到这一点。在joblib文档中提供了以下代码:

>>> shared_set = set()
>>> def collect(x):
...    shared_set.add(x)
...
>>> Parallel(n_jobs=2, require='sharedmem')(
...     delayed(collect)(i) for i in range(5))
[None, None, None, None, None]
>>> sorted(shared_set)
[0, 1, 2, 3, 4]

但是我无法理解如何使其在我的上下文中运行。您可以在此处找到源代码:

1 个答案:

答案 0 :(得分:0)

您可以使用python的管理器(https://docs.python.org/3/library/multiprocessing.html#multiprocessing.sharedctypes.multiprocessing.Manager)进行操作,例如简单的代码:

from joblib import Parallel, delayed
from multiprocessing import Manager

manager = Manager()
q = manager.Namespace()
q.flag = False

def test(i, q):
    #update shared var in 0 process
    if i == 0:
        q.flag = True

    # do nothing for few seconds
    for n in range(100000000):
        if q.flag == True:
            return f'process {i} was updated'

    return 'process {i} was not updated'

out = Parallel(n_jobs=4)(delayed(test)(i, q) for i in range(4))

退出:

['process 0 was updated',
 'process 1 was updated',
 'process 2 was updated',
 'process 3 was updated']