所有异步、协程和同步函数的装饰器?

时间:2021-03-18 06:24:42

标签: python

我为异步/协程和同步函数创建了以下装饰器。

def authorize(role):
    def decorator(f):

        @contextmanager
        def auth(): # Business logic shared by async and sync funtions
            if is_authorized(role): # role admin returns True, otherwise False
                yield
            else:
                print('ERROR')
                        
        @wraps(f)
        def wrapper(*args, **kwargs):
            if not asyncio.iscoroutinefunction(f):
                with auth():
                    return f(*args, **kwargs)
            else:
                async def tmp():
                    with auth():
                        return (await f(*args, **kwargs))
                return tmp()
        return wrapper
    return decorator

如果 is_authorized() 返回 True,效果很好。

@authorize(role='Readonly')
def test():
    print('TEST')

test()
    

但是,当 is_authorized() 返回 False 时会引发异常。如果未授权,则不应调用装饰函数,应返回 501 HTTP 错误。

@authorize(role='Readonly')
def test():
    print('TEST')
    
ERROR
Traceback (most recent call last):
  File "", line 1, in 
  File "", line 13, in wrapper
  File "C:\anaconda3\lib\contextlib.py", line 115, in __enter__
    raise RuntimeError("generator didn't yield") from None
RuntimeError: generator didn't yield

1 个答案:

答案 0 :(得分:0)

出现错误是因为 contextmanager 必须是生成器,这意味着它必须始终执行 yield 语句,因为 yield__enter____exit__ 部分分开上下文管理器。在您的实现中,仅当 is_autorized 返回 True 时才会产生。

实际上您在这里不需要 contexmanager,您需要简单的 if 语句。

我通过参数传递 is_authorized,因为它对于为测试或其他目的注入替代实现很有用。

import asyncio
import functools


def authorize(role, is_authorized):
    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if is_authorized(role):
                if asyncio.iscoroutinefunction(f):
                    async def tmp():
                        return (await f(*args, **kwargs))
                    return tmp()
                else:
                    return f(*args, **kwargs)
            elif asyncio.iscoroutinefunction(f):
                # must return coro anyway
                async def tmp():
                    print("async unauthorized")
                    return None
                return tmp()
            else:
                print("sync unauthorized")
                return None
        return wrapper
    return decorator


def is_authorized(role):
    return role == "lucky"


@authorize("lucky", is_authorized)
async def func1():
    await asyncio.sleep(0)
    return "coro finished"

@authorize("whatever", is_authorized)
async def func2():
    await asyncio.sleep(0)
    return "coro would not called"

@authorize("lucky", is_authorized)
def func3():
    return "sync func finished"

@authorize("whatever", is_authorized)
def func4():
    return "would not called"


if __name__ == "__main__":
    print(asyncio.run(func1()))
    print(asyncio.run(func2()))
    print(func3())
    print(func4())

印刷品

coro finished

async unauthorized
None

sync func finished

sync unauthorized
None
相关问题