以下代码产生不同的输出:
import numpy as np
from numba import njit
@njit
def resh_numba(a):
res = a.transpose(1, 0, 2)
res = res.copy().reshape(2, 6)
return res
x = np.arange(12).reshape(2, 2, 3)
print("numpy")
x_numpy = x.transpose(1, 0, 2).reshape(2, 6)
print(x_numpy)
print("numba:")
x_numba = resh_numba(x)
print(x_numba)
输出:
numpy
[[ 0 1 2 6 7 8]
[ 3 4 5 9 10 11]]
numba:
[[ 0 4 8 2 6 10]
[ 1 5 9 3 7 11]]
这是什么原因?我怀疑某个地方发生了order='C'
与order='F'
的冲突,但我希望numpy和numba都默认使用order='C'
。
答案 0 :(得分:2)
这是(至少)由于np.ndarray.copy实现导致的一个错误,我在这里打开了一个问题:https://github.com/numba/numba/issues/3557