我有一个在Django服务上运行的大型python应用程序。我需要关闭某些操作的权限测试,所以我创建了这个上下文管理器:
class OverrideTests(object):
def __init__(self):
self.override = 0
def __enter__(self):
self.override += 1
# noinspection PyUnusedLocal
def __exit__(self, exc_type, exc_val, exc_tb):
self.override -= 1
assert not self.override < 0
@property
def overriding(self):
return self.override > 0
override_tests = OverrideTests()
然后,应用程序的各个部分可以使用上下文管理器覆盖测试:
with override_tests:
do stuff
...
在do stuff中,上述上下文管理器可以在不同的功能中多次使用。计数器的使用使这一点得到控制,它似乎工作得很好......直到线程介入。
一旦涉及到线程,全局上下文管理器就会被重用,因此测试可能会被错误地覆盖。
这是一个简单的测试用例 - 如果thread.start_new_thread(do_id, ())
行被简单的do_it
替换,但如果显示失败,则可以正常工作:
def stat(k, expected):
x = '.' if override_tests.overriding == expected else '*'
sys.stdout.write('{0}{1}'.format(k, x))
def do_it_inner():
with override_tests:
stat(2, True)
stat(3, True) # outer with context makes this true
def do_it():
with override_tests:
stat(1, True)
do_it_inner()
stat(4, False)
def do_it_lots(ntimes=10):
for i in range(ntimes):
thread.start_new_thread(do_it, ())
如何使这个上下文管理器线程安全,以便在每个Python线程中,即使它是可重入的,它仍然被一直使用?
答案 0 :(得分:5)
以下是似乎的工作方式:使您的OverrideTests类成为threading.local
的子类。为安全起见,您应该调用__init__
中的超类__init__
(尽管即使您不这样做,它似乎仍然有效):
class OverrideTests(threading.local):
def __init__(self):
super(OverrideTests, self).__init__()
self.override = 0
# rest of class same as before
override_tests = OverrideTests()
然后:
>>> do_it_lots()
1.1.1.2.2.1.1.1.1.1.1.3.3.2.2.2.2.2.2.4.4.3.1.3.3.3.3.4.3.2.4.4.2.4.3.4.4.4.3.4.
但是,我不会在这种情况下投入资金,尤其是如果你的真实应用程序比你在这里展示的例子更复杂的话。最终,你真的应该重新考虑你的设计。在您的问题中,您正在关注如何“使上下文管理器线程安全”。但真正的问题不仅仅在于您的上下文管理器,而在于您的功能(在您的示例中为stat
)。 stat
依赖于全局状态(全局override_tests
),它在线程环境中本质上是脆弱的。
答案 1 :(得分:-1)
threading.RLock
是一个可重入的锁,可以通过同一个线程多次获取。它还支持上下文管理协议,因此可以与with
语句一起使用。
它有一个owner字段,表示当前持有锁的线程。私有方法_is_owned
告诉调用线程是否拥有锁。所有者值可用于确定当前线程是否保持锁定,这使得实现简单。
不需要计数器和线程本地存储。如果当前线程不是所有者,则表示当前线程未对该锁定进行操作,因此它不会覆盖。
import sys
from threading import RLock
try:
import _thread as thread
except ImportError:
import thread
from time import sleep
class OverrideTests(type(RLock())):
@property
def overriding(self):
return self._is_owned()
override_tests = OverrideTests()
def stat(k, expected):
x = '.' if override_tests.overriding == expected else '*'
sys.stdout.write('{0}{1}'.format(k, x))
sys.stdout.flush()
def do_it_inner():
with override_tests:
stat(2, True)
stat(3, True) # outer with context makes this true
def do_it():
with override_tests:
stat(1, True)
do_it_inner()
stat(4, False)
def do_it_lots(ntimes=10):
for _ in range(ntimes):
thread.start_new_thread(do_it,())
random_sleep()
if __name__ == '__main__':
do_it_lots()
sleep(2)