假设我已经有一个带有类型注释的方法:
class Shape:
def area(self) -> float:
raise NotImplementedError
然后我会多次进行子类化:
class Circle:
def area(self) -> float:
return math.pi * self.radius ** 2
class Rectangle:
def area(self) -> float:
return self.height * self.width
正如您所看到的,我正在重复-> float
。假设我有10种不同的形状,有这样的多种方法,其中一些也含有参数。有没有办法从父类中“复制”注释,类似于functools.wraps()
对文档字符串的处理方式?
答案 0 :(得分:4)
这可能有用,但我肯定会错过边缘情况,比如其他参数:
from functools import partial, update_wrapper
def annotate_from(f):
return partial(update_wrapper,
wrapped=f,
assigned=('__annotations__',),
updated=())
将分配"包装"函数来自__annotations__
的{{1}}属性(请记住,它不是副本)。
根据文档,已分配的update_wrapper
功能默认包含f.__annotations__
,但我可以理解为什么您不想拥有所有内容从包装分配的其他属性。
然后,您可以将__annotations__
和Circle
定义为
Rectangle
和结果
class Circle:
@annotate_from(Shape.area)
def area(self):
return math.pi * self.radius ** 2
class Rectangle:
@annotate_from(Shape.area)
def area(self):
return self.height * self.width
作为副作用,您的方法会有一个属性In [82]: Circle.area.__annotations__
Out[82]: {'return': builtins.float}
In [86]: Rectangle.area.__annotations__
Out[86]: {'return': builtins.float}
,在这种情况下会指向__wrapped__
。
使用类装饰器可以实现一个不太标准的(如果你可以调用上面使用的 update_wrapper 标准)方法来完成对重写方法的处理:
Shape.area
然后:
from inspect import getmembers, isfunction, signature
def override(f):
"""
Mark method overrides.
"""
f.__override__ = True
return f
def _is_method_override(m):
return isfunction(m) and getattr(m, '__override__', False)
def annotate_overrides(cls):
"""
Copy annotations of overridden methods.
"""
bases = cls.mro()[1:]
for name, method in getmembers(cls, _is_method_override):
for base in bases:
if hasattr(base, name):
break
else:
raise RuntimeError(
'method {!r} not found in bases of {!r}'.format(
name, cls))
base_method = getattr(base, name)
method.__annotations__ = base_method.__annotations__.copy()
return cls
同样,这不会处理带有其他参数的覆盖方法。
答案 1 :(得分:0)
您可以使用类装饰器来更新子类方法注释。在装饰器中,您需要遍历类定义,然后仅更新超类中存在的那些方法。当然要访问超类,你需要使用它__mro__
,它只是类,子类的元组,直到object
。这里我们感兴趣的是该元组中的第二个元素,它位于索引1
,因此__mro__[1]
或使用cls.mro()[1]
。最后也是最重要的是,你的装饰师必须返回班级。
def wraps_annotations(cls):
mro = cls.mro()[1]
vars_mro = vars(mro)
for name, value in vars(cls).items():
if callable(value) and name in vars_mro:
value.__annotations__.update(vars(mro).get(name).__annotations__)
return cls
演示:
>>> class Shape:
... def area(self) -> float:
... raise NotImplementedError
...
>>> import math
>>>
>>> @wraps_annotations
... class Circle(Shape):
... def area(self):
... return math.pi * self.radius ** 2
...
>>> c = Circle()
>>> c.area.__annotations__
{'return': <class 'float'>}
>>> @wraps_annotations
... class Rectangle(Shape):
... def area(self):
... return self.height * self.width
...
>>> r = Rectangle()
>>> r.area.__annotations__
{'return': <class 'float'>}