Implementing a time lock with Python's asyncio

Time: 2019-06-05 19:17:25

Tags: python locking python-asyncio

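I'm not sure whether this kind of lock is called a time lock, but I need something for the following situation: I'm making a lot of concurrent requests with aiohttp, and the server may return 429 Too Many Requests at some point. When that happens, I have to pause all subsequent requests for a while.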

I tested the lock, and its output seemed consistent with what I want in the situation described above. This is the solution I came up with:

import asyncio
import time  # needed for time.time() in lock_for


class TimeLock:

    def __init__(self, *, loop=None):
        self._locked = False
        self._locked_at = None
        self._time_lock = None
        self._unlock_task = None
        self._num_waiters = 0
        if loop is not None:
            self._loop = loop
        else:
            self._loop = asyncio.get_event_loop()

    def __repr__(self):
        state = f'locked at {self.locked_at}' if self._locked else 'unlocked'
        return f'[{state}] {self._num_waiters} waiters'

    @property
    def locked(self):
        return self._locked

    @property
    def locked_at(self):
        return self._locked_at

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # in this time lock there is nothing to do when it's released
        return

    async def acquire(self):
        if not self._locked:
            return True
        try:
            print('waiting for lock to be released')
            self._num_waiters += 1
            await self._time_lock
            self._num_waiters -= 1
            print('done, returning now')
        except asyncio.CancelledError:
            if self._locked:
                raise
        return True

    def lock_for(self, delay, lock_more=False):
        print(f'locking for {delay}')
        if self._locked:
            if not lock_more:
                # if we don't want to increase the lock time, we just exit when
                # the lock is already in a locked state
                print('already locked, nothing to do')
                return
            print('already locked, but canceling old unlock task')
            self._unlock_task.cancel()
        self._locked = True
        self._locked_at = time.time()
        self._time_lock = self._loop.create_future()
        self._unlock_task = self._loop.create_task(self.unlock_in(delay))
        print('locked')

    async def unlock_in(self, delay):
        print('unlocking started')
        await asyncio.sleep(delay)
        self._locked = False
        self._locked_at = None
        self._unlock_task = None
        self._time_lock.set_result(True)
        print('unlocked')

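Is this the right way to implement this synchronization primitive? I'm also unsure about the thread safety of this code. I don't have much experience with threads or asynchronous code.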

1 answer:

Answer 0 (score: 1)

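I haven't tested your code, but the idea seems fine. You only need to worry about thread safety if the same lock object is going to be used from different threads. As Jimmy Engelbrecht pointed out, asyncio runs in a single thread, so you usually don't have to worry about the thread safety of its primitives.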
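If another thread ever does need to touch such a primitive, one common approach is to hand the call over to the event loop rather than calling the primitive directly. The following is a minimal sketch under that assumption; the names `main`, `worker` and `resume` are illustrative and do not come from the question.

```python
import asyncio
import threading


async def main():
    loop = asyncio.get_running_loop()
    resume = asyncio.Event()

    def worker():
        # Runs in a separate thread. asyncio primitives are not
        # thread-safe, so instead of calling resume.set() here we ask
        # the loop to run it in the loop's own thread.
        loop.call_soon_threadsafe(resume.set)

    thread = threading.Thread(target=worker)
    thread.start()
    # Wait (with a safety timeout) until the loop has run resume.set().
    await asyncio.wait_for(resume.wait(), timeout=5)
    thread.join()
    return resume.is_set()


if __name__ == '__main__':
    print(asyncio.run(main()))
```

As long as every coroutine that uses the lock runs on the same loop, none of this is needed.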

Here are a few more thoughts:

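  • I'm not sure about the terminology, but it seems this primitive should be called a semaphore.
  • Instead of implementing it from scratch, you can inherit from or directly use existing primitive(s).
  • You can delegate tracking of when to pause to the semaphore itself, rather than doing it in client code.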

This snippet shows the idea:

import asyncio


class PausingSemaphore:
    def __init__(self, should_pause, pause_for_seconds):
        self.should_pause = should_pause
        self.pause_for_seconds = pause_for_seconds
        self._is_paused = False
        self._resume = asyncio.Event()

    async def __aenter__(self):
        await self.check_paused()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self.should_pause(exc):
            self.pause()

    async def check_paused(self):
        if self._is_paused:
            await self._resume.wait()

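    def pause(self):
        if not self._is_paused:
            self._is_paused = True
            # Reset the event so that waiters block again on later pauses;
            # otherwise it stays set after the first unpause.
            self._resume.clear()
            asyncio.get_running_loop().call_later(
                self.pause_for_seconds,
                self.unpause
            )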

    def unpause(self):
        self._is_paused = False
        self._resume.set()
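The snippet above ignores a pause request that arrives while it is already paused. If the pause should be extended instead, similar to `lock_more` in the question, one option is to keep the `TimerHandle` returned by `call_later` and cancel it before scheduling a new one. This is only a sketch: `ExtendablePause`, `wait_until_resumed` and `demo` are hypothetical names, and the class leaves out the context-manager part for brevity.

```python
import asyncio


class ExtendablePause:
    """Hypothetical variant: every pause() call restarts the timer."""

    def __init__(self, pause_for_seconds):
        self.pause_for_seconds = pause_for_seconds
        self._resume = asyncio.Event()
        self._resume.set()   # start in the "not paused" state
        self._timer = None   # TimerHandle of the pending unpause, if any

    @property
    def is_paused(self):
        return not self._resume.is_set()

    async def wait_until_resumed(self):
        await self._resume.wait()

    def pause(self):
        # Drop the pending unpause (if any) and schedule a fresh one.
        if self._timer is not None:
            self._timer.cancel()
        self._resume.clear()
        self._timer = asyncio.get_running_loop().call_later(
            self.pause_for_seconds, self._unpause
        )

    def _unpause(self):
        self._timer = None
        self._resume.set()


async def demo():
    gate = ExtendablePause(0.2)
    loop = asyncio.get_running_loop()
    start = loop.time()
    gate.pause()
    await asyncio.sleep(0.1)
    gate.pause()  # restarts the 0.2 s timer, so waiters wait longer
    await gate.wait_until_resumed()
    return loop.time() - start, gate.is_paused


if __name__ == '__main__':
    print(asyncio.run(demo()))
```

Because waiters block on a single `Event` that is only cleared and set, extending the pause never leaves earlier waiters stuck on a stale object.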

Let's test it:

import aiohttp


def should_pause(exc):
    return (
        type(exc) is aiohttp.ClientResponseError 
        and
        exc.status == 429
    )


pausing_sem = None
regular_sem = None


async def request(url):
    async with regular_sem:
        async with pausing_sem:
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(url, raise_for_status=True) as resp:
                        print('Done!')
            except aiohttp.ClientResponseError:
                print('Too many requests!')
                raise


async def main():
    global pausing_sem
    global regular_sem
    pausing_sem = PausingSemaphore(should_pause, 5)
    regular_sem = asyncio.Semaphore(3)

    await asyncio.gather(
        *[
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/status/429'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
        ], 
        return_exceptions=True
    )


if __name__ == '__main__':
    asyncio.run(main())

P.S. I haven't tested this code much!