MATLAB中有效的树实现

时间:2015-11-13 15:24:14

标签: performance matlab data-structures tree

MATLAB中的树类

我在MATLAB中实现树数据结构。将新的子节点添加到树中,分配和更新与节点相关的数据值是我期望执行的典型操作。每个节点都具有与之关联的相同类型的data。删除节点对我来说不是必需的。到目前为止,我已经决定继承自handle类的类实现,以便能够将对节点的引用传递给将修改树的函数。

编辑:12月2日

首先,感谢到目前为止评论和答案中的所有建议。他们已经帮助我改进了我的树类。

有人建议尝试在R2015b中引入digraph。我还没有探索过这个问题,但是因为它不像继承自handle的类那样作为引用参数工作,我对它在我的应用程序中如何工作有点怀疑。在这一点上,我还不清楚使用自定义data对节点和边缘使用它是多么容易。

编辑:(12月3日)有关主要应用程序的更多信息:MCTS

最初,我认为主要应用程序的细节只是微不足道的,但自从阅读评论和@FirefoxMetzger的answer后,我意识到它具有重要意义。

我正在实现一种Monte Carlo tree search算法。搜索树以迭代方式被探索和扩展。维基百科提供了一个很好的过程图形概述: Monte Carlo tree search

在我的应用程序中,我执行了大量的搜索迭代。在每次搜索迭代中,我遍历当前树,从根开始直到叶节点,然后通过添加新节点展开树,并重复。由于该方法基于随机抽样,在每次迭代开始时,我不知道我将在每次迭代时完成哪个叶节点。相反,这是由当前在树中的data个节点以及随机样本的结果共同确定的。无论我在单次迭代中访问哪些节点,都会更新data

示例:我在节点n,它有几个孩子。我需要访问每个孩子中的数据并绘制一个随机样本,以确定我在搜索中移动到下一个孩子。重复此过程,直到到达叶节点。实际上,我通过在根上调用search函数来执行此操作,该函数将决定下一个要扩展的子节点,递归地调用该节点上的search,依此类推,最后在叶节点处返回一个值到达。从递归函数返回时使用此值以更新在搜索迭代期间访问的节点的data

树可能非常不平衡,因此某些分支是非常长的节点链,而其他分支在根级别之后很快终止并且不会进一步扩展。

当前实施

下面是我当前实现的示例,其中包含一些用于添加节点,查询树中节点深度或数量等的成员函数的示例,等等。

classdef stree < handle
    %   A class for a tree object that acts like a reference
    %   parameter.
    %   The tree can be traversed in both directions by using the parent
    %   and children information.
    %   New nodes can be added to the tree. The object will automatically
    %   keep track of the number of nodes in the tree and increment the
    %   storage space as necessary.

    properties (SetAccess = private)
        % Hold the data at each node
        Node = { [] };
        % Index of the parent node. The root of the tree as a parent index
        % equal to 0.
        Parent = 0;
        num_nodes = 0;
        size_increment = 1;
        maxSize = 1;
    end

    methods
        function [obj, root_ID] = stree(data, init_siz)
            % New object with only root content, with specified initial
            % size
            obj.Node = repmat({ data },init_siz,1);
            obj.Parent = zeros(init_siz,1);
            root_ID = 1;
            obj.num_nodes = 1;
            obj.size_increment = init_siz;
            obj.maxSize = numel(obj.Parent);
        end

        function ID = addnode(obj, parent, data)
            % Add child node to specified parent
            if obj.num_nodes < obj.maxSize
                % still have room for data
                idx = obj.num_nodes + 1;
                obj.Node{idx} = data;
                obj.Parent(idx) = parent;
                obj.num_nodes = idx;
            else
                % all preallocated elements are in use, reserve more memory
                obj.Node = [
                    obj.Node
                    repmat({data},obj.size_increment,1)
                    ];

                obj.Parent = [
                    obj.Parent
                    parent
                    zeros(obj.size_increment-1,1)];
                obj.num_nodes = obj.num_nodes + 1;

                obj.maxSize = numel(obj.Parent);

            end
            ID = obj.num_nodes;
        end

        function content = get(obj, ID)
            %% GET  Return the contents of the given node IDs.
            content = [obj.Node{ID}];
        end

        function obj = set(obj, ID, content)
            %% SET  Set the content of given node ID and return the modifed tree.
            obj.Node{ID} = content;
        end

        function IDs = getchildren(obj, ID)
            % GETCHILDREN  Return the list of ID of the children of the given node ID.
            % The list is returned as a line vector.
            IDs = find( obj.Parent(1:obj.num_nodes) == ID );
            IDs = IDs';
        end
        function n = nnodes(obj)
            % NNODES  Return the number of nodes in the tree.
            % Equal to root + those whose parent is not root.
            n = 1 + sum(obj.Parent(1:obj.num_nodes) ~= 0);
            assert( obj.num_nodes == n);
        end

        function flag = isleaf(obj, ID)
            % ISLEAF  Return true if given ID matches a leaf node.
            % A leaf node is a node that has no children.
            flag = ~any( obj.Parent(1:obj.num_nodes) == ID );
        end

        function depth = depth(obj,ID)
            % DEPTH return depth of tree under ID. If ID is not given, use
            % root.
            if nargin == 1
                ID = 0;
            end
            if obj.isleaf(ID)
                depth = 0;
            else
                children = obj.getchildren(ID);
                NC = numel(children);
                d = 0; % Depth from here on out
                for k = 1:NC
                    d = max(d, obj.depth(children(k)));
                end
                depth = 1 + d;
            end
        end
    end
end

然而,有时性能很慢,树上的操作占用了我大部分的计算时间。有哪些具体方法可以提高实施效率?如果有性能提升,甚至可以将实现更改为除handle继承类型之外的其他内容。

使用当前实施的分析结果

由于向树中添加新节点是最典型的操作(连同更新节点的data),因此我做了一些profiling。 我使用Nd=6, Ns=10在以下基准测试代码上运行了探查器。

function T = benchmark(Nd, Ns)
% Tree benchmark. Nd: tree depth, Ns: number of nodes per layer
% Initialize tree
T = stree(rand, 10000);
add_layers(1, Nd);
    function add_layers(node_id, num_layers)
        if num_layers == 0
            return;
        end
        child_id = zeros(Ns,1);
        for s = 1:Ns
            % add child to current node
            child_id(s) = T.addnode(node_id, rand);

            % recursively increase depth under child_id(s)
            add_layers(child_id(s), num_layers-1);
        end
    end
end

分析器的结果: Profiler results

R2015b性能

已发现R2015b improves the performance of MATLAB's OOP features。我重新评估了上述基准,确实观察到了性能的提高:

R2015b profiler result

所以这已经是个好消息,虽然当然可以接受进一步的改进;)

以不同方式保留内存

评论中也建议使用

obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)];

使用repmat保留更多内存而不是当前方法。这略微改善了性能。我应该注意我的基准代码是针对虚拟数据的,因为实际的data更复杂,这可能会有所帮助。谢谢! Profiler结果如下:

zeeMonkeez memory reserve style

关于进一步提高绩效的问题

  1. 也许还有另一种方法来维护更高效的树的内存?遗憾的是,我通常不会提前知道树中有多少个节点。
  2. 添加新节点并修改现有节点的data是我在树上执行的最典型的操作。截至目前,它们实际上占用了我主要应用程序的大部分处理时间。对这些功能的任何改进都是最受欢迎的。
  3. 正如最后一点,我理想地希望将实现保持为纯MATLAB。但是,诸如MEX之类的选项或使用某些集成的Java功能可能是可以接受的。

3 个答案:

答案 0 :(得分:9)

TL:DR 您深度复制每次插入时存储的所有数据,将parentNode单元格初始化为大于您预期需要的单元格。

您的数据确实具有树结构,但是您在实现中没有使用它。相反,实现的代码是查找表的计算饥饿版本(实际上是2个表),它存储数据和树的关系数据。

我说的原因如下:

  • 要插入,请调用stree.addnote(parent,data),它将所有数据存储在树对象stree&#39}的字段Node = {}Parent = []
  • 你似乎知道你想要访问树中的哪个元素,因为没有给出搜索代码(如果你使用stree.getchild(ID),我有一些坏消息)
  • 处理完节点后,使用find()(列表搜索
  • )对其进行追溯

这绝不意味着数据的实现笨拙,它甚至可能是最好的,具体取决于您正在做什么。但它确实解释了您的内存分配问题,并提供了有关如何解决它们的提示。

将数据保留为查找表

存储数据的一种方法是保留基础查找表。我只会这样做,如果您知道要修改的第一个元素ID 而不搜索。这种情况允许您通过两个步骤提高结构效率。

首先初始化数组更大然后您希望存储数据。如果超出查找表的容量,则初始化新的容量,其中X字段更大,并且产生旧数据的深拷贝。如果您需要扩展一次或两次capcity(在所有插入期间),这可能不是问题,但在您的情况下,需要进行深层复制才能插入!

其次,我会更改内部结构并合并两个表NodeParent。这样做的原因是代码中的反向传播需要O(depth_from_root * n),其中n是表中的节点数。这是因为find()将遍历每个父级的整个表。

相反,你可以实现与

类似的东西
table = cell(n,1) % n bigger then expected value
end_pointer = 1 % simple pointer to the first free value

function insert(data,parent_ID)
    if end_pointer < numel(table)
        content.data = data;
        content.parent = parent_ID;
        table{end_pointer} = content;
        end_pointer = end_pointer + 1;
    else
        % need more space, make sure its enough this time
        table = [table cell(end_pointer,1)];
        insert(data,parent_ID);
    end
end

function content = get_value(ID)
    content = table(ID);
end

这使您可以立即访问父{q} ID而无需先find(),每步保存n次迭代,因此负担变为O(深度)。如果你不知道你的初始节点,那么你必须find()那个,这需要花费O(n)。

请注意,此结构不需要is_leaf()depth()nnodes()get_children()。如果您仍然需要这些,我需要更深入地了解您想要对数据做什么,因为这会极大地影响正确的结构。

树形结构

这种结构很有意义,如果你永远不知道第一个节点ID,那么总是必须搜索

好处是搜索任意音符与O(深度)一起工作,因此搜索是O(深度)而不是O(n),反向传播是O(深度^ 2)而不是O(深度+ n) )。请注意,对于完美平衡的树,深度可以是log(n),根据您的数据可以是任意值,对于退化树,n可以是n,只是链接列表。

然而,为了提出一些正确的建议,我需要更多的洞察力,因为每种树形结构都有自己的特色。从目前为止我所看到的情况来看,我建议使用一棵不平衡的树,这是一种排序的&#39;排序&#39;通过节点想要的父节点给出的简单顺序。这可以根据

进一步优化
  • 是否可以定义数据的总订单
  • 你如何对待双重值(相同的数据出现两次)
  • 您的数据规模(千千万万......)
  • 是一个与反向传播配对的查找/搜索
  • 亲子关系的链条有多长?&#39;关于你的数据(或者树如何平衡和深入使用这个简单的顺序)
  • 总是只有一个父母,或者是用不同父母插入两次的同一个元素

我很高兴为上面的树提供示例代码,请给我留言。

编辑: 在你的情况下,一个不平衡的树(这是建立在MCTS上的并列)似乎是最好的选择。以下代码假定数据在statescore中分开,并且state是唯一的。如果不是这样仍然有效,那么可以通过优化来提高MCTS的性能。

classdef node < handle
    % A node for a tree in a MCTS
    properties
        state = {}; %some state of the search space that identifies the node
        score = 0;
        childs = cell(50,1);
        num_childs = 0;
    end
    methods
        function obj = node(state)
            % for a new node simulate a score using MC
            obj.score = simulate_from(state); % TODO implement simulation state -> finish
            obj.state = state;
        end
        function value = update(obj)
            % update the this node using MC recursively
            if obj.num_childs == numel(obj.childs)
                % there are to many childs, we have to expand the table
                obj.childs = [obj.childs cell(obj.num_childs,1)];
            end
            if obj.do_exploration() || obj.num_childs == 0
                % explore a potential state
                state_to_explore = obj.explore();

                %check if state has already been visited
                terminate = false;
                idx = 1;
                while idx <= obj.num_childs && ~terminate
                    if obj.childs{idx}.state_equals(state_to_explore)
                        terminate = true;
                    end
                    idx = idx + 1;
                end

                %preform the according action based on search
                if idx > obj.num_childs
                    % state has never been visited
                    % this action terminates the update recursion 
                    % and creates a new leaf
                    obj.num_childs = obj.num_childs + 1;
                    obj.childs{obj.num_childs} = node(state_to_explore);
                    value = obj.childs{obj.num_childs}.calculate_value();
                    obj.update_score(value);
                else
                    % state has been visited at least once
                    value = obj.childs{idx}.update();
                    obj.update_score(value);
                end
            else
                % exploit what we know already
                best_idx = 1;
                for idx = 1:obj.num_childs
                    if obj.childs{idx}.score > obj.childs{best_idx}.score
                        best_idx = idx;
                    end
                end
                value = obj.childs{best_idx}.update();
                obj.update_score(value);
            end
            value = obj.calculate_value();
        end
        function state = explore(obj)
            %select a next state to explore, that may or may not be visited
            %TODO
        end
        function bool = do_exploration(obj)
            % decide if this node should be explored or exploited
            %TODO
        end
        function bool = state_equals(obj, test_state)
            % returns true if the nodes state is equal to test_state
            %TODO
        end
        function update_score(obj, value)
            % updates the score based on some value
            %TODO
        end
        function calculate_value(obj)
            % returns the value of this node to update previous nodes
            %TODO
        end
    end
end

关于代码的一些评论:

  • 根据设置,可能不需要obj.calculate_value()。例如。如果它是一个可以通过单独评估孩子的分数来计算的值
  • 如果state可以有多个父项,那么重用笔记对象并将其覆盖在结构中是有意义的
  • 当每个node知道其所有子节点时,可以使用node作为根节点轻松生成子树
  • 搜索树(没有任何更新)是一个简单的递归贪婪搜索
  • 取决于搜索的分支因素,可能值得访问每个可能的子项一次(在节点初始化时),然后执行randsample(obj.childs,1)进行探索,因为这样可以避免复制/重新分配子数组< / LI>
  • parent属性被编码为树以递归方式更新,在完成节点更新后将value传递给父
  • 我重新分配内存的唯一一次是当一个节点有超过50个子节点时,我只为该节点重新分配

这应该运行得更快,因为它只是担心选择树的任何部分而不接触任何其他部分。

答案 1 :(得分:6)

我知道这可能听起来很愚蠢...但是如何保持自由节点的数量而不是节点的总数?这需要与常量(即零)进行比较,这是单一属性访问。

另一个伏都教改进方法是将.maxSize移到.num_nodes附近,并将放在<{em} .Node小区之前。像这样,他们在内存中的位置不会相对于对象的开头发生变化,因为.Node属性的增长(这里的巫术是我猜测MATLAB中对象的内部实现)。

稍后编辑当我使用在属性列表末尾移动的.Node进行分析时,通过扩展.Node属性消耗了大部分执行时间,如预期(5.45秒,相比之下您提到的比较为1.25秒)。

答案 2 :(得分:4)

你可以尝试分配一些与你实际填充的元素数成比例的元素:这是c ++中std :: vector的标准实现

obj.Node = [obj.Node; data; cell(q * obj.num_nodes,1)];

我不记得确切,但在MSCC中q为1,而GCC为0.75。

这是一个使用Java的解决方案。我不太喜欢它,但它确实发挥了作用。我实现了你从维基百科中提取的例子。

import javax.swing.tree.DefaultMutableTreeNode

% Let's create our example tree
top = DefaultMutableTreeNode([11,21])
n1 = DefaultMutableTreeNode([7,10])
top.add(n1)
n2 = DefaultMutableTreeNode([2,4])
n1.add(n2)
n2 = DefaultMutableTreeNode([5,6])
n1.add(n2)
n3 = DefaultMutableTreeNode([2,3])
n2.add(n3)
n3 = DefaultMutableTreeNode([3,3])
n2.add(n3)
n1 = DefaultMutableTreeNode([4,8])
top.add(n1)
n2 = DefaultMutableTreeNode([1,2])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n1 = DefaultMutableTreeNode([0,3])
top.add(n1)

% Element to look for, your implementation will be recursive
searching = [0 1 1];
idx = 1;
node(idx) = top;
for item = searching,
    % Java transposes the matrices, remember to transpose back when you are reading
    node(idx).getUserObject()'
    node(idx+1) = node(idx).getChildAt(item);
    idx = idx + 1;
end
node(idx).getUserObject()'

% We made a new test...
newdata = [0, 1]
newnode = DefaultMutableTreeNode(newdata)
% ...so we expand our tree at the last node we searched
node(idx).add(newnode)

% The change has to be propagated (this is where your recursion returns)
for it=length(node):-1:1,
    itnode=node(it);
    val = itnode.getUserObject()'
    newitemdata = val + newdata
    itnode.setUserObject(newitemdata)
end

% Let's see if the new values are correct
searching = [0 1 1 0];
idx = 1;
node(idx) = top;
for item = searching,
    node(idx).getUserObject()'
    node(idx+1) = node(idx).getChildAt(item);
    idx = idx + 1;
end
node(idx).getUserObject()'