在numpy中重新定义* =运算符

时间:2016-07-30 12:18:33

标签: python arrays numpy subclass

作为mentioned herehere,这在numpy 1.7+中不再起作用了:

import numpy
A = numpy.array([1, 2, 3, 4], dtype=numpy.int16)
B = numpy.array([0.5, 2.1, 3, 4], dtype=numpy.float64)
A *= B

解决方法是:

def mult(a,b):
    numpy.multiply(a, b, out=a, casting="unsafe")

def add(a,b):
    numpy.add(a, b, out=a, casting="unsafe")

mult(A,B)

但是为每个矩阵操作编写的时间太长了!

默认情况下如何覆盖numpy *=运算符?

我应该继承一些东西吗?

2 个答案:

答案 0 :(得分:6)

您可以使用np.set_numeric_ops覆盖数组算术方法:

import numpy as np

def unsafe_multiply(a, b, out=None):
    return np.multiply(a, b, out=out, casting="unsafe")

np.set_numeric_ops(multiply=unsafe_multiply)

A = np.array([1, 2, 3, 4], dtype=np.int16)
B = np.array([0.5, 2.1, 3, 4], dtype=np.float64)
A *= B

print(repr(A))
# array([ 0,  4,  9, 16], dtype=int16)

答案 1 :(得分:1)

您可以创建常规函数并将预期属性传递给它:

def calX(a,b, attr):
    try:
        return getattr(numpy, attr)(a, b, out=a, casting="unsafe")
    except AttributeError:
        raise Exception("Please enter a valid attribute")

演示:

>>> import numpy
>>> A = numpy.array([1, 2, 3, 4], dtype=numpy.int16)
>>> B = numpy.array([0.5, 2.1, 3, 4], dtype=numpy.float64)
>>> calX(A, B, 'multiply')
array([ 0,  4,  9, 16], dtype=int16)
>>> calX(A, B, 'subtract')
array([ 0,  1,  6, 12], dtype=int16)

请注意,如果要覆盖结果,只需将函数的返回值分配给第一个矩阵。

A = calX(A, B, 'multiply')