我有一个基于NumPy的库,有一些类会重载算术运算。由于大量的错误检查,内部有点毛茸茸,但我遇到了一个严重的问题,我是怎么做的。
该库背后的想法是使程序员使用最少的语法和努力使其非常容易和直观。因此,我希望能够很容易地组合不同数据类型的数组,并简单地将较窄的数据类型转换为更广泛的情况。
例如,如果我有两个数组,一个是dtype float64
而另一个是dtype complex128
,那么在将它们添加到一起时,我想将float64
转换为{{1但是,如果它是complex128
和float64
,我想转换为它。但是,如果它是complex192
和float64
的组合,则两者之间没有有效转换而不会丢失complex64
的精度,因此我希望将两者都转换为{ {1}}。
如果我希望我的库完全可靠,我立即看到了我必须寻找每种类型组合并确定它们最窄的常见加宽类型(想想最不常见的多重)的问题。我不想要将所有内容转换为最宽泛的类型,因为这会非常快速地降低内存效率,而且我经常在内存中存储非常大的数组。
有没有一种好方法可以确定两种NumPy类型之间最窄的常见加宽类型?
答案 0 :(得分:5)
@amaurea有正确的想法;事实上,函数已经存在于numpy中。 请查看result_type和promote_types。
答案 1 :(得分:2)
我认为numpy已经在内部做到这一点,那么如何只是问numpy?
(zeros([],dtype=float64)+zeros([],dtype=complex64)).dtype
=> dtype('complex128')
(zeros([],dtype=float32)+zeros([],dtype=complex64)).dtype
=> dtype('complex64')
您可以将其概括为一个函数:
def common_dtype(dtypes):
return np.sum([np.zeros([],dtype=d) for d in dtypes]).dtype