我正在使用 MCTS算法为棋盘游戏开发智能代理。 蒙特卡罗树搜索(MCTS)是AI中一种流行的方法,主要用于游戏(如GO,Chess,...)。在此方法中,代理根据状态构建树,这些状态是选择当前状态允许的移动的结果。允许代理在有限时间内搜索树。在此期间,Agent将树扩展到最有希望的节点(用于赢得游戏)。 下图显示了该过程:
有关详细信息,请查看以下链接:
1 - http://www.cameronius.com/research/mcts/about/index.html
在树的根节点中,会有一个变量rootstate
,它显示游戏的当前状态。当我们深入树中时,rootstate
的深度复制用于模拟树状态(未来状态)。
我将此代码用于deepcopy
gamestate
类,因为deepcopy
因为pickle协议问题而无法正常使用cython对象:
cdef class gamestate:
# ... other functions
def __deepcopy__(self,memo_dictionary):
res = gamestate(self.size)
res.PLAYERS = self.PLAYERS
res.size = int(self.size)
res.board = np.array(self.board, dtype=np.int32)
res.white_groups = deepcopy(self.white_groups) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups) # a module which checks if black player has won the game
# the black_groups and white_groups are also cython objects which the same deepcopy function is implemented for them
# .... etc
return res
每当MCTS迭代开始时,状态的深度复制将存储在内存中。
发生的问题是游戏的开始,
每1秒的迭代次数在2000到3000之间,这是预期的,但随着游戏树的扩展,每1秒的迭代次数降低为1。迭代需要更多时间
完成。
当我检查内存使用情况时,我注意到每次调用代理进行搜索时,从0.6%增加到90%。我在 pure python 中实现了相同的算法,它没有这种类型的问题。所以我猜 __ deepcopy__函数会导致问题。我曾经被建议在here中为cython对象制定我自己的pickle
协议,但我对pickle
模块并不是很熟悉。
任何人都可以建议我使用一些协议来为我的cython对象摆脱这个障碍。
编辑2:
我添加了一些可能有用的代码部分。
以下代码属于类unionfind
的深度查看,用于white_groups
中的black_groups
和gamestate
:
cdef class unionfind:
cdef public:
dict parent
dict rank
dict groups
list ignored
cdef __init__(self):
# initialize variables ...
def __deepcopy__(self, memo_dictionary):
res = unionfind()
res.parent = self.parent
res.rank = self.rank
res.groups = self.groups
res.ignored = self.ignored
return res
这是在允许的时间内运行的搜索功能:
cdef class mctsagent:
def search(time_budget):
cdef int num_rollouts = 0
while (num_rollouts < time_budget):
state_copy = deepcopy(self.rootstate)
node, state = self.select_node(state_copy) # expansion runs inside the select_node function
turn = state.turn()
outcome = self.roll_out(state)
self.backup(node, turn, outcome)
num_rollouts += 1
答案 0 :(得分:1)
这个问题可能是行
res.white_groups = deepcopy(self.white_groups) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups) # a module which checks if black player has won the game
您应该做的是使用第二个参数deepcopy
调用memo_dictionary
。这是deepcopy
的记录,如果它已经复制了一个对象。没有它deepcopy
最终会多次复制同一个对象(因此使用大量内存)
res.white_groups = deepcopy(self.white_groups, memo_dictionary) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups, memo_dictionary) # a module which checks if black player has won the game
(编辑:刚看到@Blckknght已经在评论中指出了这一点)
(edit2:unionfind
看起来主要包含Python对象。它可能不是一个很大的值cdef class
而不仅仅是普通的类。而且,你当前的{{1}因为它实际上并没有复制那些字典 - 你应该做__deepcopy__
等等。如果你只是把它作为一个普通的类,这将自动实现)