在numba jitted nopython函数中,我需要使用另一个数组中的值来索引数组。两个数组都是numpy数组浮点数。
例如
@numba.jit("void(f8[:], f8[:], f8[:])", nopython=True)
def need_a_cast(sources, indices, destinations):
for i in range(indices.size):
destinations[i] = sources[indices[i]]
我的代码不同,但让我们假设这个愚蠢的例子可以重现这个问题(即我不能有int类型的索引)。 AFAIK,我不能在nopython jit函数中使用int(indices [i])或indices [i] .astype(“int”)。
我该怎么做?
答案 0 :(得分:2)
至少使用numba 0.24,你可以做一个简单的演员:
import numpy as np
import numba as nb
@nb.jit(nopython=True)
def need_a_cast(sources, indices, destinations):
for i in range(indices.size):
destinations[i] = sources[int(indices[i])]
sources = np.arange(10, dtype=np.float64)
indices = np.arange(10, dtype=np.float64)
np.random.shuffle(indices)
destinations = np.empty_like(sources)
print indices
need_a_cast(sources, indices, destinations)
print destinations
# Result
# [ 3. 2. 8. 1. 5. 6. 9. 4. 0. 7.]
# [ 3. 2. 8. 1. 5. 6. 9. 4. 0. 7.]
答案 1 :(得分:2)
如果您真的无法使用int(indices[i])
(适用于JoshAdel,也适用于我),您应该可以使用math.trunc
或math.floor
来处理它:
import math
...
destinations[i] = sources[math.trunc(indices[i])] # truncate (py2 and py3)
destinations[i] = sources[math.floor(indices[i])] # round down (only py3)
据我所知, math.floor
仅适用于Python3,因为它在Python2中返回float
。但另一方面,math.trunc
会反复出现负值。