如何覆盖numpy函数,如下例所示(为参数dtype设置不同的默认值;如何找到属于例如numpy.array的父类?)
import numpy as np
a = [[1.0, 2, 3], [1, 2, 3]]
np1 = np.array(a)
print(np1.dtype) # gives float64
# How do I override np.array to setting float16 and avoid the manual operation?
np1.dtype = 'float16'
print(np1.dtype)
# I think it could be something like this
class myarray(np.ndarray):
def array(self):
print('my array')
super(myarray, self).array(dtype = 'float16')
np2 = np.array(a) # myarray is not working ..
print(np2)
print(np2.dtype)
谢谢&亲切的问候
答案 0 :(得分:1)
实例化时应使用myarray
类。
np2 = myarray.array(a) # myarray will work now
答案 1 :(得分:1)
子类化ndarray
并不简单,np.array
无论如何都不是该类的方法。相反,它只是一个返回新数组的模块级函数。你可以将它包装在你自己的函数中:
def myarray(*args, **kwargs):
"""Create an array with forced dtype."""
return np.array(*args, **kwargs, dtype='float16')
print(myarray([1, 2, 3]).dtype) # float16
如果您想更改numpp.array
的行为,可以monkey patch该功能(非常不鼓励):
def myarray(*args, **kwargs):
kwargs['dtype'] = 'float16' # override any dtype argument
return np.core.multiarray.array(*args, **kwargs) # use actual internal function to avoid infinite recursion
np.array = myarray # apply monkey patch
print(np.array([1, 2, 3]).dtype) # float16