在类中使用装饰器检查管理员授权

时间:2020-06-28 02:00:26

标签: python python-3.x class oop decorator

在Python中,我创建了一个User类,该类可能具有两个UserTypeRegularAdmin中的一个。 User类具有多种方法,我希望其中的某些方法只能由管理员访问。

当前,我有以下代码:

from enum import Enum


class AuthorizationError(Exception):
    """Raised when a user attempts an admin-restricted task"""


class UserType(Enum):
    Regular = 0
    Admin = 1


class User:

    def __init__(self, username, user_type):
        self.username = username
        self.user_type = user_type

    def admin_required(self, func):
        def wrapper(*args, **kwargs):
            if self.user_type is UserType.Admin:
                return func(*args, **kwargs)
            else:
                raise AuthorizationError(f"User must be an admin to use {func.__name__}.")

        return wrapper

    def do_something_regular(self):
        print(f"{self.username} is doing something any regular user can do.")

    @admin_required
    def do_something_admin(self):
        print(f"{self.username} is doing something only an admin can do.")


me = User("MyUsername", UserType.Admin)
me.do_something_regular()
me.do_something_admin()

这会产生以下错误:

Traceback (most recent call last):
  File "example.py", line 13, in <module>
    class User:
  File "example.py", line 31, in User
    @admin_required
TypeError: admin_required() missing 1 required positional argument: 'func'

我了解我可能可以为管理员创建一个子类,但是目标是在User类中使用装饰器来检查管理员权限

我认为问题在于,当我包装do_something_admin函数时,do_something_admin传递给self参数,而不是self作为类的实例传递

我无法解决此问题。请记住,我想在解决方案中使用装饰器。谢谢!

2 个答案:

答案 0 :(得分:0)

仅在将方法绑定到类的实例时填充self参数。

在这种情况下,所引用的函数(装饰器)是未绑定的,因此需要提供第一个参数(self)。

如果您删除了self参数,那么它将起作用,例如:

class User:

    ...

    def admin_required(func):
        def wrapper(self, *args, **kwargs):
            if self.user_type is UserType.Admin:
                return func(self, *args, **kwargs)
            else:
                raise AuthorizationError(f"User must be an admin to use {func.__name__}.")

        return wrapper

   ...

答案 1 :(得分:0)

这是一个肮脏的把戏,但是它将允许您在方法上使用装饰器:

with open('file_from_python_text', 'w') as f: f.write('x\n')