在python中使运算符重载更少冗余?

时间:2018-03-06 23:44:23

标签: python python-3.x operator-overloading

我正在编写一个重载列表类型的类。 我刚写了这篇文章,我想知道是否还有其他方法可以减少冗余:

class Vector:
    def __mul__(self, other):
        #Vector([1, 2, 3]) * 5 => Vector([5, 10, 15])
        if isinstance(other, int) or isinstance(other, float):
            tmp = list()
            for i in self.l:
                tmp.append(i * other)
            return Vector(tmp)
        raise VectorException("We can only mul a Vector by a scalar")

    def __truediv__(self, other):
        #Vector([1, 2, 3]) / 5 => Vector([0.2, 0.4, 0.6])
        if isinstance(other, int) or isinstance(other, float):
            tmp = list()
            for i in self.l:
                tmp.append(i / other)
            return Vector(tmp)
        raise VectorException("We can only div a Vector by a Scalar")

    def __floordiv__(self, other):
        #Vector([1, 2, 3]) // 2 => Vector([0, 1, 1])
        if isinstance(other, int) or isinstance(other, float):
            tmp = list()
            for i in self.l:
                tmp.append(i // other)
            return Vector(tmp)
        raise VectorException("We can only div a Vector by a Scalar")

正如您所看到的,每个重载方法都是前一个的复制/粘贴,只需稍加改动。

4 个答案:

答案 0 :(得分:19)

使用decorator design patternlambda function对代码进行分解:

class Vector:
    def __do_it(self, other, func):
        if isinstance(other, int) or isinstance(other, float):
            tmp = list()
            for i in self.l:
                tmp.append(func(i, other))
            return Vector(tmp)
        raise ValueError("We can only operate a Vector by a scalar")

    def __mul__(self, other):
        return self.__do_it(other, lambda i, o: i * o)

    def __truediv__(self, other):
        return self.__do_it(other, lambda i, o: i / o)

    def __floordiv__(self, other):
        return self.__do_it(other, lambda i, o: i // o)

答案 1 :(得分:18)

您要在此处动态生成方法。有多种方法可以做到这一点,从超级动态到动态创建它们在元类__getattribute__中(虽然这对某些特殊方法不起作用 - 参见the docs) 生成源文本以保存在.py文件中,然后您可以import。但最简单的解决方案是在类定义中创建它们,如下所示:

class MyVector:
    # ...

    def _make_op_method(op):
        def _op(self, other):
            if isinstance(other, int) or isinstance(other, float):
                tmp = list()
                for i in self.l:
                    tmp.append(op(i. other))
                return Vector(tmp)
            raise VectorException("We can only {} a Vector by a scalar".format(
                op.__name__.strip('_'))
        _.op.__name__ = op.__name__
        return _op

    __mul__ = _make_op(operator.__mul__)
    __truediv__ = _make_op(operator.__truediv__)
    # and so on

您可以更好地将_op.__doc__设置为您生成的相应文档字符串(请参阅stdlib中的functools.wraps了解相关代码),然后构建__rmul__和{ {1}}与构建__imul__的方式相同,依此类推。你可以编写一个元类,类装饰器或函数生成器来包装一些细节,如果你要做同样的事情的许多变化。但这是基本的想法。

事实上,将它移到课堂之外可以更容易地消除更多的重复。只需在类中定义__mul__方法而不是_op(self, other, op)内的本地方法,并使用_make_op修饰类,您可以这样定义:

@numeric_ops

如果你看一下,例如def numeric_ops(cls): for op in "mul truediv floordiv".split(): # "mul truediv floordiv ... ".split(): def _op(self, other): return self._op(other, getattr(operator, op) _op.__name__ = f"__{op}__" setattr(cls, f"__{op}__", _op) ,它会做类似的事情来生成任何缺少的排序操作。

functions.total_ordering等来自stdlib中的operator模块 - 它们只是简单的函数,其中operator.mul基本上只调用operator.__mul__(x, y),依此类推,当您需要将运算符表达式作为函数传递时。

stdlib中有一些这类代码的例子 - 尽管有更多相关但更简单x * y的例子。

这里的关键是,您使用__rmul__ = __mul__创建的名称与通过def分配创建的名称之间没有区别。无论哪种方式,=都成为类的属性,它的值是一个完成你想要的功能。 (同样地,在类定义期间创建的名称与之后注入的名称之间几乎没有区别。)

那么,你应该这样做吗?

嗯,DRY非常重要。如果你复制 - 粘贴 - 编辑十几次,那么你不可能搞砸其中一个编辑并最终得到一个实际上是倍数的mod方法(以及一个没有捕获它的单元测试)。然后,如果你后来发现了你复制和粘贴的实现中的一个缺陷(在问题的原始版本和编辑版本之间),你必须在十几个地方修复同样的缺陷,这是另一个潜在的错误磁铁。

另一方面,可读性很重要。如果你不明白它是如何起作用的,你可能不应该这样做,并且应该满足于斋戒的答案。 (它不是那么紧凑,也不是那么高效,但它肯定更容易理解。)毕竟,如果代码对您来说不明显,那么您只需要修复一次而不是十几次的错误就会被淹没因为你不知道如何解决它。即使你明白这一点,聪明的代价通常也会超过DRY的好处。

我认为__mul__显示了您想要绘制线条的位置。如果你这样做一次,你最好重复一遍,但如果你是为多个班或多个项目做的那样,你可能最好把聪明抽象成一个你可以写一次的库,用各种不同的类进行详尽的测试,然后反复使用。

答案 2 :(得分:9)

您的代码可以像下面一样紧凑(juanpa.arrivillaga建议return NotImplemented而不是引发异常):

def __mul__(self, other):
    #Vector([1, 2, 3]) * 5 => Vector([5, 10, 15])
    try:
        return Vector([i * other for i in self.l])
    except TypeError:
        return NotImplemented

答案 3 :(得分:7)

战略模式是你的朋友。我还将介绍其他几种清理代码的方法。

您可以在此处阅读有关策略模式的信息:https://en.wikipedia.org/wiki/Strategy_pattern

你说"正如你所看到的,每个重载方法都是前一个的复制/粘贴,只需要很小的改动。"这是您使用此模式的提示。如果你可以将一个小变化变成一个函数,那么你可以编写样板代码一次并关注有趣的部分。

class Vector:
    def _arithmitize(self, other, f, error_msg):
        if isinstance(other, int) or isinstance(other, float):
            tmp = list()
            for a in self.l:
                tmp.append(func(a, other))
            return Vector(tmp)
        raise ValueError(error_msg)

    def _err_msg(self, op_name):
        return "We can only {} a vector by a scalar".format(opp_name)

    def __mul__(self, other):
        return self._arithmitize(
            other, 
            lambda a, b: a * b, 
            self._err_msg('mul'))

    def __div__(self, other):
        return self._arithmitize(
            other, 
            lambda a, b: a / b, 
            self._err_msg('div'))
    # and so on ...

我们可以通过列表理解来清理这一点

class Vector:
    def _arithmetize(self, other, f, error_msg):
        if isinstance(other, int) or isinstance(other, float):
            return Vector([f(a, other) for a in self.l])
        raise ValueError(error_msg)

    def _err_msg(self, op_name):
        return "We can only {} a vector by a scalar".format(opp_name)

    def __mul__(self, other):
        return self._arithmetize(
            other, 
            lambda a, b: a * b, 
            self._err_msg('mul'))

    def __div__(self, other):
        return self._arithmetize(
            other, 
            lambda a, b: a / b, 
            self._err_msg('div'))

我们可以改进类型检查

import numbers

class Vector:
    def _arithmetize(self, other, f, error_msg):
        if isinstance(other, number.Numbers):
            return Vector([f(a, other) for a in self.l])
        raise ValueError(error_msg)

我们可以使用运算符代替编写lambdas:

import operators as op

class Vector:
    # snip ...
    def __mul__(self, other):
        return self._arithmetize(other, op.mul, self._err_msg('mul'))

所以我们最终得到这样的东西:

import numbers
import operators as op

class Vector(object):
    def _arithmetize(self, other, f, err_msg):
        if isinstance(other, numbers.Number):
            return Vector([f(a, other) for a in self.l])
        raise ValueError(self._error_msg(err_msg))
    def _error_msg(self, msg):
        return "We can only {} a vector by a scalar".format(opp_name)

    def __mul__(self, other):
        return self._arithmetize(op.mul, other, 'mul')

    def __truediv__(self, other):
        return self._arithmetize(op.truediv, other, 'truediv')

    def __floordiv__(self, other):
        return self._arithmetize(op.floordiv, other, 'floordiv')

    def __mod__(self, other):
        return self._arithmetize(op.mod, other, 'mod')

    def __pow__(self, other):
        return self._arithmetize(op.pow, other, 'pow')

还有其他方法可以动态生成这些,但对于像这样的一小部分函数,​​可读性很重要。

如果您需要动态生成这些内容,请尝试以下操作:

class Vector(object):
    def _arithmetize(....):
        # you've seen this already 

    def __getattr__(self, name):
        funcs = {
            '__mul__': op.mul, # note: this may not actually work with dunder methods. YMMV
            '__mod__': op.mod,
            ...
        }
        def g(self, other):
            try:
                return self._arithmetize(funcs[name],...)
             except:
                 raise NotImplementedError(...)
        return g

如果您发现此动态示例不起作用,请查看make operators overloading less redundant in python?,它处理在大多数python实现中动态创建dunder_methods的情况。