我的目标是坚持复杂算法的每次迭代的完整状态,该算法还涉及通过pycuda
生成的伪随机数。为了在任意迭代中恢复算法并确定性地重现相同的结果,我需要类似get_state() and set_state() from numpy.random.RandomState
考虑到这一点:
from pycuda.curandom import XORWOWRandomNumberGenerator
gen = XORWOWRandomNumberGenerator()
如何将gen
的完整状态加载到numpy
数组中?
如何根据以前获得的gen
数组重现完全相同的numpy
状态?
答案 0 :(得分:1)
我没有找到开箱即用的解决方案。因此,我从XORWOWRandomNumberGenerator
:
from pycuda.curandom import XORWOWRandomNumberGenerator
import pycuda.driver as drv
import numpy
class PersistableXORWOWRandomNumberGenerator(XORWOWRandomNumberGenerator):
def get_state_size(self):
from pycuda.characterize import sizeof
data_type_size = sizeof(self.state_type, "#include <curand_kernel.h>")
return self.block_count * self.generators_per_block * data_type_size
def get_state(self, ary=None):
if ary is None:
ary = drv.from_device(self.state, (self.get_state_size(),), numpy.uint8)
else:
drv.memcpy_dtoh(ary, self.state)
return ary
def set_state(self, state):
drv.memcpy_htod(self.state, state)
可以像这样使用:
gen = PersistableXORWOWRandomNumberGenerator()
# obtain the random state as a uint8 numpy array
state_as_numpy = gen.get_state()
# set the state from a uint8 numpy array
gen.set_state(state_as_numpy)