如何为持久性获取和设置XORWOWRandomNumberGenerator状态?

时间:2018-01-30 12:06:58

标签: python pycuda

我的目标是坚持复杂算法的每次迭代的完整状态,该算法还涉及通过pycuda生成的伪随机数。为了在任意迭代中恢复算法并确定性地重现相同的结果,我需要类似get_state() and set_state() from numpy.random.RandomState

的内容

考虑到这一点:

from pycuda.curandom import XORWOWRandomNumberGenerator
gen = XORWOWRandomNumberGenerator()

如何将gen的完整状态加载到numpy数组中?

如何根据以前获得的gen数组重现完全相同的numpy状态?

1 个答案:

答案 0 :(得分:1)

我没有找到开箱即用的解决方案。因此,我从XORWOWRandomNumberGenerator

派生了以下类
from pycuda.curandom import XORWOWRandomNumberGenerator
import pycuda.driver as drv
import numpy


class PersistableXORWOWRandomNumberGenerator(XORWOWRandomNumberGenerator):
    def get_state_size(self):
        from pycuda.characterize import sizeof
        data_type_size = sizeof(self.state_type, "#include <curand_kernel.h>")
        return self.block_count * self.generators_per_block * data_type_size

    def get_state(self, ary=None):
        if ary is None:
            ary = drv.from_device(self.state, (self.get_state_size(),), numpy.uint8)
        else:
            drv.memcpy_dtoh(ary, self.state)
        return ary

    def set_state(self, state):
        drv.memcpy_htod(self.state, state)

可以像这样使用:

gen = PersistableXORWOWRandomNumberGenerator()

# obtain the random state as a uint8 numpy array
state_as_numpy = gen.get_state()

# set the state from a uint8 numpy array
gen.set_state(state_as_numpy)