我可以创建一个本地numpy随机种子吗?

时间:2018-03-29 12:22:59

标签: python numpy random scope

有一个使用foo功能的函数np.random。 我想控制foo使用的种子,但实际上没有改变函数本身。 我该怎么做?

基本上我想要这样的东西:

bar() # should have normal seed
with np.random.seed(0): # Doesn't work
    foo()
bar() # should have normal seed

解决方案如 this

rng = random.Random(42)
number = rng.randint(10, 20)

在这种情况下不起作用,因为我无法访问foo的内部工作(或者我错过了什么?)。

2 个答案:

答案 0 :(得分:6)

您可以将全局随机状态保存在临时变量中,并在完成功能后重置它:

import contextlib
import numpy as np

@contextlib.contextmanager
def temp_seed(seed):
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)

演示:

>>> np.random.seed(0)
>>> np.random.randn(3)
array([1.76405235, 0.40015721, 0.97873798])
>>> np.random.randn(3)
array([ 2.2408932 ,  1.86755799, -0.97727788])

>>> np.random.seed(0)
>>> np.random.randn(3)
array([1.76405235, 0.40015721, 0.97873798])
>>> with temp_seed(5):
...     np.random.randn(3)                                                                                        
array([ 0.44122749, -0.33087015,  2.43077119])
>>> np.random.randn(3)
array([ 2.2408932 ,  1.86755799, -0.97727788])

答案 1 :(得分:1)

我认为这个想法是,当给出一个起始种子时,对bar()的调用总是应该看到相同的随机数序列;无论在{之间插入foo()的次数是多少。

我们可以通过从临时种子状态完成时用于重新种子的随机状态创建随机种子来实现此目的。这可以包含在上下文管理器中:

import numpy as np

class temporary_seed:
    def __init__(self, seed):
        self.seed = seed
        self.backup = None

    def __enter__(self):
        self.backup = np.random.randint(2**32-1, dtype=np.uint32)
        np.random.seed(self.seed)

    def __exit__(self, *_):
        np.random.seed(self.backup)

让我们试试

def bar():
    print('bar:', np.random.randint(10))

def foo():
    print('foo:', np.random.randint(10))

np.random.seed(999)

bar()  # bar: 0
with temporary_seed(42):
    foo()  # foo: 6
    foo()  # foo: 3
bar()  # bar: 9

因此我们得到条序列[0,9]和foo-序列[6,3]。

我们再次尝试不再全球播种:

bar()  # bar: 1
with temporary_seed(42):
    foo()  # foo: 6
    foo()  # foo: 3
bar()  # bar: 2

新的条形序列[1,2]和相同的foo序列[6,3]。

再次使用相同的全球种子,但foo的种子不同:

np.random.seed(999)

bar()  # bar: 0
with temporary_seed(0):
    foo()  # foo: 5
bar()  # bar: 9

这次我们再次获得第一个条形序列[0,9]和不同的foo。尼斯!

那捕获量在哪里?通过进入和离开临时种子部分,我们改变随机状态。我们确定性地这样做并且结果是可重复的,但是如果我们不调用输入temorary_seed我们得到不同的序列:

np.random.seed(999)

bar()  # bar: 0
bar()  # bar: 5

bar-sequence [0,5]而不是[0,9]。如果你能忍受这种限制,这种方法应该有效。