我为异步/协程和同步函数创建了以下装饰器。
def authorize(role):
def decorator(f):
@contextmanager
def auth(): # Business logic shared by async and sync funtions
if is_authorized(role): # role admin returns True, otherwise False
yield
else:
print('ERROR')
@wraps(f)
def wrapper(*args, **kwargs):
if not asyncio.iscoroutinefunction(f):
with auth():
return f(*args, **kwargs)
else:
async def tmp():
with auth():
return (await f(*args, **kwargs))
return tmp()
return wrapper
return decorator
如果 is_authorized()
返回 True,效果很好。
@authorize(role='Readonly')
def test():
print('TEST')
test()
但是,当 is_authorized()
返回 False 时会引发异常。如果未授权,则不应调用装饰函数,应返回 501 HTTP 错误。
@authorize(role='Readonly')
def test():
print('TEST')
ERROR Traceback (most recent call last): File "", line 1, in File "", line 13, in wrapper File "C:\anaconda3\lib\contextlib.py", line 115, in __enter__ raise RuntimeError("generator didn't yield") from None RuntimeError: generator didn't yield
答案 0 :(得分:0)
出现错误是因为 contextmanager
必须是生成器,这意味着它必须始终执行 yield
语句,因为 yield
将 __enter__
和 __exit__
部分分开上下文管理器。在您的实现中,仅当 is_autorized
返回 True
时才会产生。
实际上您在这里不需要 contexmanager
,您需要简单的 if
语句。
我通过参数传递 is_authorized
,因为它对于为测试或其他目的注入替代实现很有用。
import asyncio
import functools
def authorize(role, is_authorized):
def decorator(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
if is_authorized(role):
if asyncio.iscoroutinefunction(f):
async def tmp():
return (await f(*args, **kwargs))
return tmp()
else:
return f(*args, **kwargs)
elif asyncio.iscoroutinefunction(f):
# must return coro anyway
async def tmp():
print("async unauthorized")
return None
return tmp()
else:
print("sync unauthorized")
return None
return wrapper
return decorator
def is_authorized(role):
return role == "lucky"
@authorize("lucky", is_authorized)
async def func1():
await asyncio.sleep(0)
return "coro finished"
@authorize("whatever", is_authorized)
async def func2():
await asyncio.sleep(0)
return "coro would not called"
@authorize("lucky", is_authorized)
def func3():
return "sync func finished"
@authorize("whatever", is_authorized)
def func4():
return "would not called"
if __name__ == "__main__":
print(asyncio.run(func1()))
print(asyncio.run(func2()))
print(func3())
print(func4())
印刷品
coro finished
async unauthorized
None
sync func finished
sync unauthorized
None