我是否必须重写子类中的所有数学运算符?

时间:2018-11-14 22:16:12

标签: python

我想在仅实现一些功能的Python 3.7程序中制作一个简单的Point2d类。我在一个SO答案(我现在找不到)中看到,创建Point类的一种方法是覆盖complex,所以我这样写:

import math

class Point2d(complex):

    def distanceTo(self, otherPoint):
        return math.sqrt((self.real - otherPoint.real)**2 + (self.imag - otherPoint.imag)**2)

    def x(self):
        return self.real

    def y(self):
        return self.imag

这有效:

In [48]: p1 = Point2d(3, 3)

In [49]: p2 = Point2d(6, 7)

In [50]: p1.distanceTo(p2)
Out[50]: 5.0

但是当我这样做时,p3complex的实例,而不是Point2d

In [51]: p3 = p1 + p2

In [52]: p3.distanceTo(p1)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-52-37fbadb3015e> in <module>
----> 1 p3.distanceTo(p1)

AttributeError: 'complex' object has no attribute 'distanceTo'

我的大部分背景是在Objective-C和C#中使用的,因此我仍在尝试找出执行此类操作的pythonic方法。我是否需要覆盖要在Point2d类上使用的所有数学运算符?还是我会完全以错误的方式来解决这个问题?

4 个答案:

答案 0 :(得分:1)

问题在于,您的类在使用属于复杂对象的任何数据模型函数时会返回复杂对象,因此您需要将其提交给您的Point2d

添加此方法应该可以解决问题

def __add__(self, b):
    return Point2d(super().__add__(b))

但是仍然应该有一个更好的方法。但这是动态包装某些“数据模型”(dunder)方法的方法。

顺便说一下,距离函数可以使它变得更短

def distanceTo(self, otherPoint):
    return abs(self - otherPoint)

答案 1 :(得分:1)

在这种情况下,我建议从头开始实现您的类Point2d

如果您很懒,请查看一些类似sympy的库,其中包括Point类和其他几何元素https://docs.sympy.org/latest/modules/geometry/index.html

答案 2 :(得分:1)

我将提到一种无需手动编写每个方法即可覆盖所有方法的方法,但这仅是因为我们全都consenting adults。我真的不建议这样做,如果您仅覆盖每个操作,就会更加清楚。就是说,您可以编写一个类包装器,以检查基类的所有方法,并将输出转换为点(如果它是复杂类型)。

import math
import inspect


def convert_to_base(cls):
    def decorate_func(name, method, base_cls):
        def method_wrapper(*args, **kwargs):
            obj = method(*args, **kwargs)
            return cls.convert(obj, base_cls) if isinstance(obj, base_cls) else obj
        return method_wrapper if name not in ('__init__', '__new__') else method
    for base_cls in cls.__bases__:
        for name, method in inspect.getmembers(base_cls, inspect.isroutine):  # Might want to change this filter
            setattr(cls, name, decorate_func(name, method, base_cls))
    return cls


@convert_to_base
class Point2d(complex):

    @classmethod
    def convert(cls, obj, base_cls):
        # Use base_cls if you need to know which base class to convert.
        return cls(obj.real, obj.imag)

    def distanceTo(self, otherPoint):
        return math.sqrt((self.real - otherPoint.real)**2 + (self.imag - otherPoint.imag)**2)

    def x(self):
        return self.real

    def y(self):
        return self.imag

p1 = Point2d(3, 3)
p2 = Point2d(6, 7)
p3 = p1 + p2
p4 = p3.distanceTo(p1)
print(p4)

# 9.219544457292887

这里发生的是,它只检查基类的所有方法,如果返回的是基类的类型,则将其转换为子类,这由子类中的特殊类方法定义课。

答案 3 :(得分:0)

通常,与继承相比,更喜欢使用组合。您可以根据复数实现所有所需的操作。

class Point2D:
    def __init__(self, x, y):
        self._p = complex(x,y)

    @classmethod
    def _from_complex(self, z):
        return Point2D(z.real, z.imag)

    def __add__(self, other):
        return Point2D._from_complex(self._p + other._p)

    # etc

很多样板吗?是。但这并不是您可以避免的样板。