如何覆盖numpy方法

时间:2018-02-27 08:54:38

标签: python-3.x numpy oop

如何覆盖numpy函数,如下例所示(为参数dtype设置不同的默认值;如何找到属于例如numpy.array的父类?)

import numpy as np

a = [[1.0, 2, 3], [1, 2, 3]]

np1 = np.array(a)
print(np1.dtype) # gives float64

# How do I override np.array to setting float16 and avoid the manual operation?
np1.dtype = 'float16'
print(np1.dtype)

# I think it could be something like this
class myarray(np.ndarray):
    def array(self):
        print('my array')
        super(myarray, self).array(dtype = 'float16')

np2 = np.array(a) # myarray is not working ..
print(np2)
print(np2.dtype)

谢谢&亲切的问候

2 个答案:

答案 0 :(得分:1)

实例化时应使用myarray类。

np2 = myarray.array(a) # myarray will work now

答案 1 :(得分:1)

子类化ndarray并不简单,np.array无论如何都不是该类的方法。相反,它只是一个返回新数组的模块级函数。你可以将它包装在你自己的函数中:

def myarray(*args, **kwargs):
    """Create an array with forced dtype."""
    return np.array(*args, **kwargs, dtype='float16')

print(myarray([1, 2, 3]).dtype)  # float16

如果您想更改numpp.array的行为,可以monkey patch该功能(非常不鼓励):

def myarray(*args, **kwargs):
    kwargs['dtype'] = 'float16'  # override any dtype argument
    return np.core.multiarray.array(*args, **kwargs)  # use actual internal function to avoid infinite recursion

np.array = myarray  # apply monkey patch

print(np.array([1, 2, 3]).dtype)  # float16