我在MATLAB中实现树数据结构。将新的子节点添加到树中,分配和更新与节点相关的数据值是我期望执行的典型操作。每个节点都具有与之关联的相同类型的data
。删除节点对我来说不是必需的。到目前为止,我已经决定继承自handle
类的类实现,以便能够将对节点的引用传递给将修改树的函数。
首先,感谢到目前为止评论和答案中的所有建议。他们已经帮助我改进了我的树类。
有人建议尝试在R2015b中引入digraph
。我还没有探索过这个问题,但是因为它不像继承自handle
的类那样作为引用参数工作,我对它在我的应用程序中如何工作有点怀疑。在这一点上,我还不清楚使用自定义data
对节点和边缘使用它是多么容易。
最初,我认为主要应用程序的细节只是微不足道的,但自从阅读评论和@FirefoxMetzger的answer后,我意识到它具有重要意义。
我正在实现一种Monte Carlo tree search算法。搜索树以迭代方式被探索和扩展。维基百科提供了一个很好的过程图形概述:
在我的应用程序中,我执行了大量的搜索迭代。在每次搜索迭代中,我遍历当前树,从根开始直到叶节点,然后通过添加新节点展开树,并重复。由于该方法基于随机抽样,在每次迭代开始时,我不知道我将在每次迭代时完成哪个叶节点。相反,这是由当前在树中的data
个节点以及随机样本的结果共同确定的。无论我在单次迭代中访问哪些节点,都会更新data
。
示例:我在节点n
,它有几个孩子。我需要访问每个孩子中的数据并绘制一个随机样本,以确定我在搜索中移动到下一个孩子。重复此过程,直到到达叶节点。实际上,我通过在根上调用search
函数来执行此操作,该函数将决定下一个要扩展的子节点,递归地调用该节点上的search
,依此类推,最后在叶节点处返回一个值到达。从递归函数返回时使用此值以更新在搜索迭代期间访问的节点的data
。
树可能非常不平衡,因此某些分支是非常长的节点链,而其他分支在根级别之后很快终止并且不会进一步扩展。
下面是我当前实现的示例,其中包含一些用于添加节点,查询树中节点深度或数量等的成员函数的示例,等等。
classdef stree < handle
% A class for a tree object that acts like a reference
% parameter.
% The tree can be traversed in both directions by using the parent
% and children information.
% New nodes can be added to the tree. The object will automatically
% keep track of the number of nodes in the tree and increment the
% storage space as necessary.
properties (SetAccess = private)
% Hold the data at each node
Node = { [] };
% Index of the parent node. The root of the tree as a parent index
% equal to 0.
Parent = 0;
num_nodes = 0;
size_increment = 1;
maxSize = 1;
end
methods
function [obj, root_ID] = stree(data, init_siz)
% New object with only root content, with specified initial
% size
obj.Node = repmat({ data },init_siz,1);
obj.Parent = zeros(init_siz,1);
root_ID = 1;
obj.num_nodes = 1;
obj.size_increment = init_siz;
obj.maxSize = numel(obj.Parent);
end
function ID = addnode(obj, parent, data)
% Add child node to specified parent
if obj.num_nodes < obj.maxSize
% still have room for data
idx = obj.num_nodes + 1;
obj.Node{idx} = data;
obj.Parent(idx) = parent;
obj.num_nodes = idx;
else
% all preallocated elements are in use, reserve more memory
obj.Node = [
obj.Node
repmat({data},obj.size_increment,1)
];
obj.Parent = [
obj.Parent
parent
zeros(obj.size_increment-1,1)];
obj.num_nodes = obj.num_nodes + 1;
obj.maxSize = numel(obj.Parent);
end
ID = obj.num_nodes;
end
function content = get(obj, ID)
%% GET Return the contents of the given node IDs.
content = [obj.Node{ID}];
end
function obj = set(obj, ID, content)
%% SET Set the content of given node ID and return the modifed tree.
obj.Node{ID} = content;
end
function IDs = getchildren(obj, ID)
% GETCHILDREN Return the list of ID of the children of the given node ID.
% The list is returned as a line vector.
IDs = find( obj.Parent(1:obj.num_nodes) == ID );
IDs = IDs';
end
function n = nnodes(obj)
% NNODES Return the number of nodes in the tree.
% Equal to root + those whose parent is not root.
n = 1 + sum(obj.Parent(1:obj.num_nodes) ~= 0);
assert( obj.num_nodes == n);
end
function flag = isleaf(obj, ID)
% ISLEAF Return true if given ID matches a leaf node.
% A leaf node is a node that has no children.
flag = ~any( obj.Parent(1:obj.num_nodes) == ID );
end
function depth = depth(obj,ID)
% DEPTH return depth of tree under ID. If ID is not given, use
% root.
if nargin == 1
ID = 0;
end
if obj.isleaf(ID)
depth = 0;
else
children = obj.getchildren(ID);
NC = numel(children);
d = 0; % Depth from here on out
for k = 1:NC
d = max(d, obj.depth(children(k)));
end
depth = 1 + d;
end
end
end
end
然而,有时性能很慢,树上的操作占用了我大部分的计算时间。有哪些具体方法可以提高实施效率?如果有性能提升,甚至可以将实现更改为除handle
继承类型之外的其他内容。
由于向树中添加新节点是最典型的操作(连同更新节点的data
),因此我做了一些profiling。
我使用Nd=6, Ns=10
在以下基准测试代码上运行了探查器。
function T = benchmark(Nd, Ns)
% Tree benchmark. Nd: tree depth, Ns: number of nodes per layer
% Initialize tree
T = stree(rand, 10000);
add_layers(1, Nd);
function add_layers(node_id, num_layers)
if num_layers == 0
return;
end
child_id = zeros(Ns,1);
for s = 1:Ns
% add child to current node
child_id(s) = T.addnode(node_id, rand);
% recursively increase depth under child_id(s)
add_layers(child_id(s), num_layers-1);
end
end
end
已发现R2015b improves the performance of MATLAB's OOP features。我重新评估了上述基准,确实观察到了性能的提高:
所以这已经是个好消息,虽然当然可以接受进一步的改进;)
以不同方式保留内存
评论中也建议使用
obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)];
使用repmat
保留更多内存而不是当前方法。这略微改善了性能。我应该注意我的基准代码是针对虚拟数据的,因为实际的data
更复杂,这可能会有所帮助。谢谢! Profiler结果如下:
data
是我在树上执行的最典型的操作。截至目前,它们实际上占用了我主要应用程序的大部分处理时间。对这些功能的任何改进都是最受欢迎的。正如最后一点,我理想地希望将实现保持为纯MATLAB。但是,诸如MEX之类的选项或使用某些集成的Java功能可能是可以接受的。
答案 0 :(得分:9)
TL:DR 您深度复制每次插入时存储的所有数据,将parent
和Node
单元格初始化为大于您预期需要的单元格。
您的数据确实具有树结构,但是您在实现中没有使用它。相反,实现的代码是查找表的计算饥饿版本(实际上是2个表),它存储数据和树的关系数据。
我说的原因如下:
stree
&#39}的字段Node = {}
和Parent = []
find()
(列表搜索这绝不意味着数据的实现笨拙,它甚至可能是最好的,具体取决于您正在做什么。但它确实解释了您的内存分配问题,并提供了有关如何解决它们的提示。
存储数据的一种方法是保留基础查找表。我只会这样做,如果您知道要修改的第一个元素ID
而不搜索。这种情况允许您通过两个步骤提高结构效率。
首先初始化数组更大然后您希望存储数据。如果超出查找表的容量,则初始化新的容量,其中X字段更大,并且产生旧数据的深拷贝。如果您需要扩展一次或两次capcity(在所有插入期间),这可能不是问题,但在您的情况下,需要进行深层复制才能插入!
其次,我会更改内部结构并合并两个表Node
和Parent
。这样做的原因是代码中的反向传播需要O(depth_from_root * n),其中n是表中的节点数。这是因为find()将遍历每个父级的整个表。
相反,你可以实现与
类似的东西table = cell(n,1) % n bigger then expected value
end_pointer = 1 % simple pointer to the first free value
function insert(data,parent_ID)
if end_pointer < numel(table)
content.data = data;
content.parent = parent_ID;
table{end_pointer} = content;
end_pointer = end_pointer + 1;
else
% need more space, make sure its enough this time
table = [table cell(end_pointer,1)];
insert(data,parent_ID);
end
end
function content = get_value(ID)
content = table(ID);
end
这使您可以立即访问父{q} ID
而无需先find()
,每步保存n次迭代,因此负担变为O(深度)。如果你不知道你的初始节点,那么你必须find()
那个,这需要花费O(n)。
请注意,此结构不需要is_leaf()
,depth()
,nnodes()
或get_children()
。如果您仍然需要这些,我需要更深入地了解您想要对数据做什么,因为这会极大地影响正确的结构。
这种结构很有意义,如果你永远不知道第一个节点ID
,那么总是必须搜索。
好处是搜索任意音符与O(深度)一起工作,因此搜索是O(深度)而不是O(n),反向传播是O(深度^ 2)而不是O(深度+ n) )。请注意,对于完美平衡的树,深度可以是log(n),根据您的数据可以是任意值,对于退化树,n可以是n,只是链接列表。
然而,为了提出一些正确的建议,我需要更多的洞察力,因为每种树形结构都有自己的特色。从目前为止我所看到的情况来看,我建议使用一棵不平衡的树,这是一种排序的&#39;排序&#39;通过节点想要的父节点给出的简单顺序。这可以根据
进一步优化我很高兴为上面的树提供示例代码,请给我留言。
编辑:
在你的情况下,一个不平衡的树(这是建立在MCTS上的并列)似乎是最好的选择。以下代码假定数据在state
和score
中分开,并且state
是唯一的。如果不是这样仍然有效,那么可以通过优化来提高MCTS的性能。
classdef node < handle
% A node for a tree in a MCTS
properties
state = {}; %some state of the search space that identifies the node
score = 0;
childs = cell(50,1);
num_childs = 0;
end
methods
function obj = node(state)
% for a new node simulate a score using MC
obj.score = simulate_from(state); % TODO implement simulation state -> finish
obj.state = state;
end
function value = update(obj)
% update the this node using MC recursively
if obj.num_childs == numel(obj.childs)
% there are to many childs, we have to expand the table
obj.childs = [obj.childs cell(obj.num_childs,1)];
end
if obj.do_exploration() || obj.num_childs == 0
% explore a potential state
state_to_explore = obj.explore();
%check if state has already been visited
terminate = false;
idx = 1;
while idx <= obj.num_childs && ~terminate
if obj.childs{idx}.state_equals(state_to_explore)
terminate = true;
end
idx = idx + 1;
end
%preform the according action based on search
if idx > obj.num_childs
% state has never been visited
% this action terminates the update recursion
% and creates a new leaf
obj.num_childs = obj.num_childs + 1;
obj.childs{obj.num_childs} = node(state_to_explore);
value = obj.childs{obj.num_childs}.calculate_value();
obj.update_score(value);
else
% state has been visited at least once
value = obj.childs{idx}.update();
obj.update_score(value);
end
else
% exploit what we know already
best_idx = 1;
for idx = 1:obj.num_childs
if obj.childs{idx}.score > obj.childs{best_idx}.score
best_idx = idx;
end
end
value = obj.childs{best_idx}.update();
obj.update_score(value);
end
value = obj.calculate_value();
end
function state = explore(obj)
%select a next state to explore, that may or may not be visited
%TODO
end
function bool = do_exploration(obj)
% decide if this node should be explored or exploited
%TODO
end
function bool = state_equals(obj, test_state)
% returns true if the nodes state is equal to test_state
%TODO
end
function update_score(obj, value)
% updates the score based on some value
%TODO
end
function calculate_value(obj)
% returns the value of this node to update previous nodes
%TODO
end
end
end
关于代码的一些评论:
obj.calculate_value()
。例如。如果它是一个可以通过单独评估孩子的分数来计算的值state
可以有多个父项,那么重用笔记对象并将其覆盖在结构中是有意义的node
知道其所有子节点时,可以使用node
作为根节点轻松生成子树randsample(obj.childs,1)
进行探索,因为这样可以避免复制/重新分配子数组< / LI>
parent
属性被编码为树以递归方式更新,在完成节点更新后将value
传递给父这应该运行得更快,因为它只是担心选择树的任何部分而不接触任何其他部分。
答案 1 :(得分:6)
我知道这可能听起来很愚蠢...但是如何保持自由节点的数量而不是节点的总数?这需要与常量(即零)进行比较,这是单一属性访问。
另一个伏都教改进方法是将.maxSize
移到.num_nodes
附近,并将放在<{em} .Node
小区之前。像这样,他们在内存中的位置不会相对于对象的开头发生变化,因为.Node
属性的增长(这里的巫术是我猜测MATLAB中对象的内部实现)。
稍后编辑当我使用在属性列表末尾移动的.Node
进行分析时,通过扩展.Node
属性消耗了大部分执行时间,如预期(5.45秒,相比之下您提到的比较为1.25秒)。
答案 2 :(得分:4)
你可以尝试分配一些与你实际填充的元素数成比例的元素:这是c ++中std :: vector的标准实现
obj.Node = [obj.Node; data; cell(q * obj.num_nodes,1)];
我不记得确切,但在MSCC中q
为1,而GCC为0.75。
这是一个使用Java的解决方案。我不太喜欢它,但它确实发挥了作用。我实现了你从维基百科中提取的例子。
import javax.swing.tree.DefaultMutableTreeNode
% Let's create our example tree
top = DefaultMutableTreeNode([11,21])
n1 = DefaultMutableTreeNode([7,10])
top.add(n1)
n2 = DefaultMutableTreeNode([2,4])
n1.add(n2)
n2 = DefaultMutableTreeNode([5,6])
n1.add(n2)
n3 = DefaultMutableTreeNode([2,3])
n2.add(n3)
n3 = DefaultMutableTreeNode([3,3])
n2.add(n3)
n1 = DefaultMutableTreeNode([4,8])
top.add(n1)
n2 = DefaultMutableTreeNode([1,2])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n1 = DefaultMutableTreeNode([0,3])
top.add(n1)
% Element to look for, your implementation will be recursive
searching = [0 1 1];
idx = 1;
node(idx) = top;
for item = searching,
% Java transposes the matrices, remember to transpose back when you are reading
node(idx).getUserObject()'
node(idx+1) = node(idx).getChildAt(item);
idx = idx + 1;
end
node(idx).getUserObject()'
% We made a new test...
newdata = [0, 1]
newnode = DefaultMutableTreeNode(newdata)
% ...so we expand our tree at the last node we searched
node(idx).add(newnode)
% The change has to be propagated (this is where your recursion returns)
for it=length(node):-1:1,
itnode=node(it);
val = itnode.getUserObject()'
newitemdata = val + newdata
itnode.setUserObject(newitemdata)
end
% Let's see if the new values are correct
searching = [0 1 1 0];
idx = 1;
node(idx) = top;
for item = searching,
node(idx).getUserObject()'
node(idx+1) = node(idx).getChildAt(item);
idx = idx + 1;
end
node(idx).getUserObject()'