如何在子类中键入注释覆盖方法?

时间:2016-03-30 10:59:34

标签: python python-3.x types

假设我已经有一个带有类型注释的方法:

class Shape:
    def area(self) -> float:
        raise NotImplementedError

然后我会多次进行子类化:

class Circle:
    def area(self) -> float:
        return math.pi * self.radius ** 2

class Rectangle:
    def area(self) -> float:
        return self.height * self.width

正如您所看到的,我正在重复-> float。假设我有10种不同的形状,有这样的多种方法,其中一些也含有参数。有没有办法从父类中“复制”注释,类似于functools.wraps()对文档字符串的处理方式?

2 个答案:

答案 0 :(得分:4)

这可能有用,但我肯定会错过边缘情况,比如其他参数:

from functools import partial, update_wrapper


def annotate_from(f):
    return partial(update_wrapper,
                   wrapped=f,
                   assigned=('__annotations__',),
                   updated=())

将分配"包装"函数来自__annotations__的{​​{1}}属性(请记住,它不是副本)。

根据文档,已分配update_wrapper功能默认包含f.__annotations__,但我可以理解为什么您不想拥有所有内容从包装分配的其他属性。

然后,您可以将__annotations__Circle定义为

Rectangle

和结果

class Circle:
    @annotate_from(Shape.area)
    def area(self):
        return math.pi * self.radius ** 2

class Rectangle:
    @annotate_from(Shape.area)
    def area(self):
        return self.height * self.width

作为副作用,您的方法会有一个属性In [82]: Circle.area.__annotations__ Out[82]: {'return': builtins.float} In [86]: Rectangle.area.__annotations__ Out[86]: {'return': builtins.float} ,在这种情况下会指向__wrapped__

使用类装饰器可以实现一个不太标准的(如果你可以调用上面使用的 update_wrapper 标准)方法来完成对重写方法的处理:

Shape.area

然后:

from inspect import getmembers, isfunction, signature


def override(f):
    """
    Mark method overrides.
    """
    f.__override__ = True
    return f


def _is_method_override(m):
    return isfunction(m) and getattr(m, '__override__', False)


def annotate_overrides(cls):
    """
    Copy annotations of overridden methods.
    """
    bases = cls.mro()[1:]
    for name, method in getmembers(cls, _is_method_override):
        for base in bases:
            if hasattr(base, name):
                break

        else:
            raise RuntimeError(
                    'method {!r} not found in bases of {!r}'.format(
                            name, cls))

        base_method = getattr(base, name)
        method.__annotations__ = base_method.__annotations__.copy()

    return cls

同样,这不会处理带有其他参数的覆盖方法。

答案 1 :(得分:0)

您可以使用类装饰器来更新子类方法注释。在装饰器中,您需要遍历类定义,然后仅更新超类中存在的那些方法。当然要访问超类,你需要使用它__mro__,它只是类,子类的元组,直到object。这里我们感兴趣的是该元组中的第二个元素,它位于索引1,因此__mro__[1]或使用cls.mro()[1]。最后也是最重要的是,你的装饰师必须返回班级。

def wraps_annotations(cls):
    mro = cls.mro()[1] 
    vars_mro = vars(mro)
    for name, value in vars(cls).items():
        if callable(value) and name in vars_mro:
            value.__annotations__.update(vars(mro).get(name).__annotations__)
    return cls

演示:

>>> class Shape:
...     def area(self) -> float:
...         raise NotImplementedError
...
>>> import math
>>>
>>> @wraps_annotations
... class Circle(Shape):
...     def area(self):
...         return math.pi * self.radius ** 2
...
>>> c = Circle()
>>> c.area.__annotations__
{'return': <class 'float'>}
>>> @wraps_annotations
... class Rectangle(Shape):
...     def area(self):
...         return self.height * self.width
...
>>> r = Rectangle()
>>> r.area.__annotations__
{'return': <class 'float'>}