Python 3 concurrent.futures和每线程初始化

时间:2016-05-04 03:43:15

标签: multithreading python-3.x concurrent.futures

在Python 3中,是否可以在Thread的上下文中使用concurrent.futures.ThreadPoolExecutor的子类,以便在处理(可能是很多)工作项之前可以单独初始化它们?

我想使用方便的concurrent.futures API来处理同步文件和S3对象的代码(如果相应的S3对象不存在或不在外,每个工作项都是一个要同步的文件的同步)。我希望每个工作线程首先进行一些初始化,例如设置boto3.session.Session。然后,该工作线程池将准备好处理可能数千个工作项(要同步的文件)。

顺便说一句,如果一个线程由于某种原因而死亡,是否有理由期望自动创建一个新线程并将其添加回池中?

(免责声明:我比Java更熟悉Java的多线程框架)。

1 个答案:

答案 0 :(得分:2)

因此,似乎我的问题的一个简单解决方案是使用threading.local来存储每个线程的“会话”(在下面的模型中,只是一个随机的int)。也许不是我想的最干净,但现在它会做。这是一个模型(Python 3.5.1):

import time
import threading
import concurrent.futures
import random
import logging

logging.basicConfig(level=logging.DEBUG, format='(%(threadName)-0s) %(relativeCreated)d - %(message)s')

x = [0.1, 0.1, 0.2, 0.4, 1.0, 0.1, 0.0]

mydata = threading.local()

def do_work(secs):
    if 'session' in mydata.__dict__:
        logging.debug('re-using session "{}"'.format(mydata.session))
    else:
        mydata.session = random.randint(0,1000)
        logging.debug('created new session: "{}"'.format(mydata.session))
    time.sleep(secs)
    logging.debug('slept for {} seconds'.format(secs))
    return secs

with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    y = executor.map(do_work, x)

print(list(y))

生成以下输出,显示“sessions”确实是每个线程的本地并重用:

(Thread-1) 29 - created new session: "855"
(Thread-2) 29 - created new session: "58"
(Thread-3) 30 - created new session: "210"
(Thread-1) 129 - slept for 0.1 seconds
(Thread-1) 130 - re-using session "855"
(Thread-2) 130 - slept for 0.1 seconds
(Thread-2) 130 - re-using session "58"
(Thread-3) 230 - slept for 0.2 seconds
(Thread-3) 230 - re-using session "210"
(Thread-3) 331 - slept for 0.1 seconds
(Thread-3) 331 - re-using session "210"
(Thread-3) 331 - slept for 0.0 seconds
(Thread-1) 530 - slept for 0.4 seconds
(Thread-2) 1131 - slept for 1.0 seconds
[0.1, 0.1, 0.2, 0.4, 1.0, 0.1, 0.0]

关于日志记录的一些注意事项:为了在IPython笔记本中使用它,需要稍微修改日志记录设置(因为IPython已经设置了根记录器)。更强大的日志记录设置将是:

IN_IPYNB = 'get_ipython' in vars()

if IN_IPYNB:
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    for h in logger.handlers:
        h.setFormatter(logging.Formatter(
                '(%(threadName)-0s) %(relativeCreated)d - %(message)s'))
else:
    logging.basicConfig(level=logging.DEBUG, format='(%(threadName)-0s) %(relativeCreated)d - %(message)s')