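I'm trying to optimize this function with Numba. It depends on a jitclass, which currently works without problems. When I run the function, however, it raises a TypingError at the line x = np.digitize(x1, bins). The error seems to say that Numba's jit does not support np.digitize, at least not with these argument types.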
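import numpy as np

def eps_q_learning(env, episodes=500, eps=.5, lr=.8, y=.95, decay_factor=.999):
    q_table = np.zeros((26, 2, 2))
    # 21 bin edges from -10 to 10, so np.digitize returns an index in 0..21
    bins = np.array([-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    for i in range(episodes):
        x1, news, done = env.reset()[0:3]
        # Discretize the continuous observation x1 (a scalar float64)
        x = np.digitize(x1, bins)
        # Decay the exploration rate each episode
        eps *= decay_factor
        # Epsilon-greedy action selection: explore at random, or if no Q-values are set yet
        if np.random.random() < eps or np.sum(q_table[x, int(news), :]) == 0:
            a = np.random.randint(0, 2)
        else:
            a = np.argmax(q_table[x, int(news), :])
        _, news, done, reward = env.step(a)
        # Update the Q-value of the chosen action
        q_table[x, int(news), a] += reward + lr * (y * np.max(q_table[x, int(news), :]) - q_table[x, int(news), a])
    return q_table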
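One possible workaround, sketched here under the assumption that bins is sorted in increasing order (as it is above): replace the scalar np.digitize call with a small hand-written binary search that returns the same index as np.digitize(x, bins) with the default right=False, i.e. the i for which bins[i-1] <= x < bins[i]. The helper name digitize_scalar is hypothetical, and the njit import falls back to plain Python when Numba is not installed, so the sketch runs either way.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # Numba not installed: run the helper as plain Python
    def njit(func):
        return func


@njit
def digitize_scalar(x, bins):
    """Hypothetical helper: scalar equivalent of np.digitize(x, bins)
    for bins sorted in increasing order (right=False semantics).

    Returns the number of bin edges that are <= x, i.e. the index i
    with bins[i-1] <= x < bins[i].
    """
    lo = 0
    hi = bins.shape[0]
    # Binary search for the first edge strictly greater than x.
    while lo < hi:
        mid = (lo + hi) // 2
        if bins[mid] <= x:
            lo = mid + 1
        else:
            hi = mid
    return lo


if __name__ == "__main__":
    bins = np.array([-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0,
                     1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # Compare against NumPy's own (non-jitted) np.digitize.
    for x1 in (-12.0, -10.0, -0.5, 0.0, 3.7, 10.0, 11.0):
        assert digitize_scalar(x1, bins) == np.digitize(x1, bins)
```

Inside eps_q_learning the line x = np.digitize(x1, bins) would then become x = digitize_scalar(x1, bins). With only 21 edges a linear scan would also do; the binary search just keeps the lookup O(log n).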
Error message:
Traceback (most recent call last):
File "<input>", line 3, in <module>
File "C:\Users\Coen D. Needell\AppData\Local\Programs\Python\Python36\lib\site-packages\numba\dispatcher.py", line 348, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Coen D. Needell\AppData\Local\Programs\Python\Python36\lib\site-packages\numba\dispatcher.py", line 315, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Coen D. Needell\AppData\Local\Programs\Python\Python36\lib\site-packages\numba\six.py", line 658, in reraise
raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function digitize>) with argument(s) of type(s): (float64, array(int64, 1d, C))
* parameterized
In definition 0:
All templates rejected with literals.
In definition 1:
All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<built-in function digitize>)
[2] During: typing of call at D:\CODE\woke-gpu\StocksGame.py (97)
File "StocksGame.py", line 97:
def eps_q_learning(env, episodes=500, eps=.5, lr=.8, y=.95, decay_factor=.999):
<source elided>
x1, news, done = env.reset()[0:3]
x = np.digitize(x1, bins)
^
This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.
To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/dev/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html
For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile
If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new
Any ideas? This is on Windows 10, with Numba version 0.41.0, Python version 3.6.7, and NumPy version 1.15.4.
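The error lists the argument types as (float64, array(int64, 1d, C)), so the rejected call passes a bare scalar float as x. Two alternative spellings that may type-check, depending on the Numba version (not verified here against 0.41.0): wrap the scalar in a one-element array and take the first result, or call np.searchsorted with side='right', which for increasing bins returns the same index as np.digitize with its default right=False. The function names below are hypothetical, and the sketch only checks in plain NumPy that both spellings agree with the original call.

```python
import numpy as np

bins = np.array([-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0,
                 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])


def bin_index_wrapped(x1, bins):
    # Hypothetical helper: pass a 1-element array instead of a bare scalar,
    # then unwrap the single result.
    return np.digitize(np.array([x1]), bins)[0]


def bin_index_searchsorted(x1, bins):
    # Hypothetical helper: for increasing bins, side='right' reproduces
    # np.digitize's default right=False convention: bins[i-1] <= x1 < bins[i].
    return np.searchsorted(bins, x1, side='right')


if __name__ == "__main__":
    for x1 in (-12.0, -10.0, -3.5, 0.0, 4.2, 10.0, 15.0):
        expected = np.digitize(x1, bins)
        assert bin_index_wrapped(x1, bins) == expected
        assert bin_index_searchsorted(x1, bins) == expected
```

Either expression could stand in for np.digitize(x1, bins) inside the jitted function; whether Numba accepts it in nopython mode should be confirmed against the supported-NumPy-features page for the installed version.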