我有一个简单的银行帐户类别,需要通过以下测试
import sys
import threading
import time
import unittest
from bank_account import BankAccount
class BankAccountTest(unittest.TestCase):
def test_can_handle_concurrent_transactions(self):
account = BankAccount()
account.open()
account.deposit(1000)
self.adjust_balance_concurrently(account)
self.assertEqual(account.get_balance(), 1000)
def adjust_balance_concurrently(self, account):
def transact():
account.deposit(5)
time.sleep(0.001)
account.withdraw(5)
# Greatly improve the chance of an operation being interrupted
# by thread switch, thus testing synchronization effectively
try:
sys.setswitchinterval(1e-12)
except AttributeError:
# For Python 2 compatibility
sys.setcheckinterval(1)
threads = [threading.Thread(target=transact) for _ in range(1000)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
我尝试阅读线程文档,但是尝试将其应用于我的情况时有点困惑。我试过的是这个
class BankAccount(Thread):
def __init__(self):
Thread.__init__(self)
self.state = False
self.balance = 0
Thread().start()
def get_balance(self):
if self.state:
return self.balance
else:
raise ValueError
def open(self):
self.state = True
def deposit(self, amount):
self.balance += amount
def withdraw(self, amount):
self.balance -= amount
这显然是错误的。我的目的只是了解如何使该类处理线程切换。如果我没有提供重要信息,请告诉我。
答案 0 :(得分:0)
我们需要确保多个线程不能同时修改余额。
deposit()
函数虽然看起来只是一步,但却是多步操作。
old_balance = self.balance
new_balance = old_balance + deposit
self.balance = new_balance
如果线程切换发生在存款中间,则可能会破坏数据。
例如,假设线程1调用deposit(10)
,线程2调用deposit(20)
,则初始余额为100
# Inside thread 1
old_balance1 = self.balance
new_balance1 = old_balance1 + 10
# Thread switches to thread 2
old_balance2 = self.balance
new_balance2 = old_balance2 + 20
self.balance = new_balance2 # balance = 120
# Thread switches back to thread 1
self.balance = new_balance1 # balance = 110
这里的最终余额为110
,而应该为130
。
解决方案是防止两个线程同时写入balance
变量。我们可以利用Lock
来完成此操作。
import threading
class BankAccount:
def open(self):
self.balance = 0
# initialize lock
self.lock = threading.Lock()
def deposit(self, val):
# if another thread has acquired lock, block till it releases
self.lock.acquire()
self.balance += val
self.lock.release()
def withdraw(self, val):
self.lock.acquire()
self.balance -= val
self.lock.release()
def get_balance(self):
return self.balance