尝试实例化NumPy RandomState时出现Numba错误

时间:2014-01-27 01:41:05

标签: python numpy numba

我有一节课,我想加速Numba。通过简单地创建具有特定种子的NumPy的RandomState实例,该类为每个实例使用“随机数生成器”(因此我可以在以后复制我的工作)。当我使用Numba的autojit时,我得到一个奇怪的错误,这在“常规”Python中不会出现。

幸运的是,这种行为非常简单。这是一个简单的例子,说明了错误:

from numpy.random import RandomState
from numba import autojit

# ------- This works in "regular Python" ------------

class SillyClass1(object):
    def __init__(self, seed):
        self.RNG = RandomState(seed)
    def draw_uniform(self):
        return self.RNG.uniform(0,1)

test1 = SillyClass1(123456)

test1.draw_uniform()

# Output:
# 0.12696983303810094


# The following code -- exactly the same as above, but with the @autojit 
# decorator, doesn't work, and throws an error which I am having a hard  
# time understanding how to fix:

@autojit
class SillyClass2(object):
    def __init__(self, seed):
        self.RNG = RandomState(seed)
    def draw_uniform(self):
        return self.RNG.uniform(0,1)

test2 = SillyClass2(123456)

test2.draw_uniform()

# Output:
#
# ValueError                                Traceback (most recent call last)
# <ipython-input-86-a18f95c11a1b> in <module>()
#      10 
#      11 
# ---> 12 test2 = SillyClass2(123456)
#      13 
#      14 test2.draw_uniform()
# 
# ...
# 
# ValueError: object of too small depth for desired array

我正在使用Anaconda发行版Ubuntu 13.10。

有什么想法?

编辑:我找到了解决办法,就是简单地使用Python的标准“random.Random”而不是NumPys的“numpy.random.RandomState”

示例:

from random import Random 
@autojit
class SillyClass3(object):
    def __init__(self, seed):
        self.RNG = Random(seed)
    def draw_uniform(self):
        return self.RNG.uniform(0,1)

test3 = SillyClass3(123456)

test3.draw_uniform()

# Output:
# 0.8056271362589

这适用于我的直接申请(虽然其他问题立即出现,欢呼)。

但是,此修复程序不适用于我知道我将需要使用numpy.random.RandomState的未来算法。所以我的问题仍然存在 - 有没有人对Numba中使用numy.random.RandomState的原始错误和/或解决方法有任何见解?

1 个答案:

答案 0 :(得分:0)

一个较晚的答案,但问题是Numba does not support Numpy的RandomState对象(截至2019年7月)。它也不支持使用Python standard library设置随机状态。

但是,似乎确实可以通过在函数中传递seed参数,然后在其中调用random.seed(seed)方法(来自Python的标准库)来在jitted函数中设置Numba的内部PRNG。

从Numba文档中:

  

注意:从非Numba代码(或从对象模式)调用random.seed()   代码)将植入Python随机生成器,而不是Numba随机生成器   发电机。

鉴于Numba支持random.seed,这暗示了以上观点。

此外,如果您以CUDA为目标,则有许多方法可以让您直接使用启用GPU的PRNG进行操作。从Numba documentation看到此页面。