在cython类上使用手动深度复制会导致内存溢出。为什么?

时间:2017-08-26 22:15:14

标签: python-3.x cython pickle deep-copy

我正在使用 MCTS算法为棋盘游戏开发智能代理。 蒙特卡罗树搜索(MCTS)是AI中一种流行的方法,主要用于游戏(如GO,Chess,...)。在此方法中,代理根据状态构建树,这些状态是选择当前状态允许的移动的结果。允许代理在有限时间内搜索树。在此期间,Agent将树扩展到最有希望的节点(用于赢得游戏)。 下图显示了该过程:

MCTS steps in a single iteration (figure from chaslot 2006)

有关详细信息,请查看以下链接:

1 - http://www.cameronius.com/research/mcts/about/index.html

在树的根节点中,会有一个变量rootstate,它显示游戏的当前状态。当我们深入树中时,rootstate的深度复制用于模拟树状态(未来状态)。

我将此代码用于deepcopy gamestate类,因为deepcopy因为pickle协议问题而无法正常使用cython对象:

cdef class gamestate:

# ... other functions

def __deepcopy__(self,memo_dictionary):
    res = gamestate(self.size)
    res.PLAYERS = self.PLAYERS
    res.size = int(self.size)
    res.board = np.array(self.board, dtype=np.int32)
    res.white_groups = deepcopy(self.white_groups) # a module which checks if white player has won the game
    res.black_groups = deepcopy(self.black_groups) # a module which checks if black player has won the game
    # the black_groups and white_groups are also cython objects which the same deepcopy function is implemented for them
    # .... etc
    return res

每当MCTS迭代开始时,状态的深度复制将存储在内存中。 发生的问题游戏的开始, 每1秒的迭代次数在2000到3000之间,这是预期的,但随着游戏树的扩展,每1秒的迭代次数降低为1。迭代需要更多时间 完成。
当我检查内存使用情况时,我注意到每次调用代理进行搜索时,从0.6%增加到90%。我在 pure python 中实现了相同的算法,它没有这种类型的问题。所以我猜 __ deepcopy__函数会导致问题。我曾经被建议在here中为cython对象制定我自己的pickle协议,但我对pickle模块并不是很熟悉。 任何人都可以建议我使用一些协议来为我的cython对象摆脱这个障碍。

编辑2:

我添加了一些可能有用的代码部分。 以下代码属于类unionfind深度查看,用于white_groups中的black_groupsgamestate

cdef class unionfind:
    cdef public:
        dict parent
        dict rank
        dict groups
        list ignored
    cdef __init__(self):
    # initialize variables ...

   def __deepcopy__(self, memo_dictionary):
       res = unionfind()
       res.parent = self.parent
       res.rank = self.rank
       res.groups = self.groups
       res.ignored = self.ignored
       return res

这是在允许的时间内运行的搜索功能:

cdef class mctsagent:
    def search(time_budget):
        cdef int num_rollouts = 0
        while (num_rollouts < time_budget):
          state_copy = deepcopy(self.rootstate)
          node, state = self.select_node(state_copy) # expansion runs inside the select_node function
          turn = state.turn()
          outcome = self.roll_out(state)
          self.backup(node, turn, outcome)
          num_rollouts += 1

1 个答案:

答案 0 :(得分:1)

这个问题可能是行

res.white_groups = deepcopy(self.white_groups) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups) # a module which checks if black player has won the game

您应该做的是使用第二个参数deepcopy调用memo_dictionary。这是deepcopy的记录,如果它已经复制了一个对象。没有它deepcopy最终会多次复制同一个对象(因此使用大量内存)

res.white_groups = deepcopy(self.white_groups, memo_dictionary) # a module which checks if white player has won the game
res.black_groups = deepcopy(self.black_groups, memo_dictionary) # a module which checks if black player has won the game
  

If the __deepcopy__() implementation needs to make a deep copy of a component, it should call the deepcopy() function with the component as first argument and the memo dictionary as second argument.

(编辑:刚看到@Blckknght已经在评论中指出了这一点)

(edit2:unionfind看起来主要包含Python对象。它可能不是一个很大的值cdef class而不仅仅是普通的类。而且,你当前的{{1}因为它实际上并没有复制那些字典 - 你应该做__deepcopy__等等。如果你只是把它作为一个普通的类,这将自动实现)