numpy.einsum有时会忽略dtype参数

时间:2018-04-01 09:03:15

标签: python numpy numpy-einsum

假设我有两个int8类型的数组。我想以这样的方式使用einsum,所有计算都将以int64完成,但我不想将整个数组转换为int64。 如果我理解正确,这就是dtype参数的用途。但它似乎并不总是奏效。

它按预期工作的示例:

>>> A = np.array([[123, 45],[67,89]], dtype='int8')
>>> np.einsum(A,[0,1],A,[0,1],[1]) ## => integer overflow
array([-94, -38], dtype=int8)
>>> np.einsum(A,[0,1],A,[0,1],[1], dtype='int64') ## => no overflow
array([19618,  9946], dtype=int64)

不能按预期工作的示例:

>>> A = np.array([[123, 45],[67,89]], dtype='int8')
>>> np.einsum(A,[0,1],A,[1,2],[0,2]) ## => integer overflow
array([[-32,  68],
       [124, -72]], dtype=int8)
>>> np.einsum(A,[0,1],A,[1,2],[0,2], dtype='int64') ## => should not overflow, but it does
array([[-32,  68],
       [124, -72]], dtype=int8)

对于这些示例,我在带有numpy 1.14.0的Windows上使用了python 3.6.4

当我尝试其他dtypes时会发生同样的情况,例如float64,即使我使用cast ='unsafe'。

为什么会发生这种情况,我该如何才能使其发挥作用?



更新
einsum默认优化= True,至少在numpy 1.14.0中。 当使用optimize = False时,它按预期工作(尽管对于大型数组来说速度要慢得多):

>>> A = np.array([[123, 45],[67,89]], dtype='int8')
>>> np.einsum(A,[0,1],A,[1,2],[0,2]) ## => integer overflow
array([[-32,  68],
       [124, -72]], dtype=int8)
>>> np.einsum(A,[0,1],A,[1,2],[0,2], dtype='int64', optimize=False)
array([[18144,  9540],
       [14204, 10936]], dtype=int64)


从简短的einsum.py看,似乎当optimize = True时,它会检查是否更好地使用numpy.tensordot而不是einsum实现。如果是(它应该在我的第二个例子中,因为它只是一个常规矩阵乘法),那么它使用tensordot,但它不会将dtype参数传递给它。事实上,tensordot甚至没有dtype论证。


如果这是正确的,那么需要一个后续问题(这可能值得自己发帖?):
我如何矩阵乘以某个dtype的两个矩阵,比如int8,这样所有的计算都将完成,例如int64或float64(因此不会溢出/松散精度,除非int64 / float64也可以,但不必将它们首先转换为所需的类型,而且,操作的实现本身不应该整体转换矩阵,每次只转换小部分(因此操作所需的内存不会大得多)比仅仅保存这些矩阵所需的内存和结果)?
这可以用与numpy.dot相当的效率来完成吗?

0 个答案:

没有答案