从平面数据中获取树结构中每个节点的所有后代

时间:2014-09-28 02:07:12

标签: java algorithm

我有一个代表等级关系的平面数据,如下所示:

ID  Name    PID
0   A       NULL
1   B       0
2   C       0
4   D       1
5   E       1
6   F       4
3   G       0

此表表示一张“数据表”,其中PID表示父元素的ID。例如,第一行中A的PID为null,而B的PID为0,这意味着B的父元素是A,因为0是A的ID;A是根元素,因为它没有PID。类似地,C的父元素也是A,因为C的PID同样是0,而0是A的ID。

我创建了一个类DataTable来表示上面的表。我还实现了processDataTable方法

public Map<String, List<String>> processDataTable()

返回的映射(Map)以元素名作为键,以该元素所有后代节点的集合作为值。例如,映射中的第一项对应于元素A,它有多个后代,而元素C没有后代。输出中元素的顺序并不重要。

public static void main(String...arg) {

    // Build the sample table from the question: (ID, Name, PID).
    // Fixed: the original read "newDataTable()", which does not compile.
    DataTable dt = new DataTable();

    dt.addRow(0, "A", null);
    dt.addRow(1, "B", 0);
    dt.addRow(2, "C", 0);
    dt.addRow(4, "D", 1);
    dt.addRow(5, "E", 1);
    dt.addRow(6, "F", 4);
    dt.addRow(3, "G", 0);

    System.out.println("Output:");
    System.out.println(dt.processDataTable());
}

Output:
{D=[F], A=[B, C, G, D, E, F], B=[D, E, F]}
or
{D=[F], E=null, F=null, G=null, A=[B, C, G, D, E, F], B=[D, E, F], C=null}

以下是我对DataTable的实现:

public class DataTable {

    private List<Record> records = new ArrayList<>();
    /** Maps a record id to its position in {@link #records}, for O(1) lookup by id. */
    private Map<Integer, Integer> indexes = new HashMap<>();
    private static final int PROCESSORS = Runtime.getRuntime().availableProcessors();

    /**
     * Adds a new record to the DataTable. If a record with the same id already
     * exists the call is ignored, so the first row added for an id wins.
     *
     * <p>Not synchronized: rows must not be added while {@link #processDataTable()}
     * is running, because its worker threads read {@code records} and
     * {@code indexes} without locking.
     *
     * @param id       id of the row
     * @param name     name of the row; used as the key of the map returned by
     *                 {@link #processDataTable()}
     * @param parentId id of the parent row, or {@code null} for a root row
     */
    public void addRow(Integer id, String name, Integer parentId) {
        if (indexes.get(id) == null) {
            Record rec = new Record(id, name, parentId);
            records.add(rec);
            indexes.put(id, records.size() - 1);
        }
    }

    /** Returns the live (mutable) list of records, in insertion order. */
    public List<Record> getRecords() {
        return records;
    }

    /**
     * Processes the DataTable and returns, for every record that has at least one
     * descendant, an entry mapping the record's name to the names of all of its
     * descendants (children, grandchildren, ...). Records without descendants are
     * omitted, and the order of names inside each list is not specified.
     *
     * <p>The work runs in two fork/join passes: first every record is linked to its
     * parent, then the descendants of each record are collected. The result is
     * inherently large for deep hierarchies: a linear chain of n records yields
     * n(n-1)/2 names in total, so memory grows quadratically with depth.
     *
     * <p>If two records share a name, only one of their entries survives (which one
     * is unspecified, because the map is filled in parallel). A record with a
     * {@code null} name and at least one descendant causes a
     * {@link NullPointerException}, because {@link ConcurrentHashMap} rejects null keys.
     *
     * @return map from record name to the names of its descendants
     */
    public Map<String, List<String>> processDataTable() {
        long start = System.currentTimeMillis();
        int size = size();

        // Step 1: link every record to its parent and register it as a child.
        invokeOnewayTask(new LinkRecordTask(this, 0, size));

        Map<String, List<String>> map = new ConcurrentHashMap<>();

        // Step 2: collect the descendants of every record into the shared map.
        invokeOnewayTask(new BuildChildrenMapTask(this, 0, size, map));

        long elapsedTime = System.currentTimeMillis() - start;

        System.out.println("Total elapsed time: " + elapsedTime + " ms");

        return map;
    }

    /**
     * Runs the given task to completion on a dedicated pool and prints how long it took.
     * The pool is always shut down afterwards. The previous version leaked one pool,
     * and its worker threads, per call.
     *
     * @param task the task to run; blocks until it completes
     */
    private void invokeOnewayTask(ForkJoinTask<?> task) {
        long start = System.currentTimeMillis();
        ForkJoinPool pool = new ForkJoinPool(PROCESSORS);
        try {
            pool.invoke(task);
        } finally {
            pool.shutdown();
        }
        long elapsedTime = System.currentTimeMillis() - start;
        System.out.println(task.getClass().getSimpleName() + ":" + elapsedTime + " ms");
    }

    /**
     * Finds a record by id.
     *
     * @param id the record id
     * @return the record, or {@code null} if no record has that id
     */
    public Record getRecordById(Integer id) {
        Integer pos = indexes.get(id);
        if (pos != null) {
            return records.get(pos);
        }
        return null;
    }

    /**
     * Finds a record by its row number (insertion position).
     *
     * @param rownum zero-based row number; {@code null} or out of range yields {@code null}
     * @return the record at that row, or {@code null}
     */
    public Record getRecordByRowNumber(Integer rownum) {
        if (rownum == null || rownum < 0 || rownum >= records.size()) {
            return null;
        }
        return records.get(rownum);
    }

    /** Returns the number of records in the table. */
    public int size() {
        return records.size();
    }

    /**
     * Sets {@code parent} on every record in [start, end) and adds the record to its
     * parent's (concurrent) children set. It splits the range in halves until fewer
     * than {@code LIMIT} records remain.
     */
    private static class LinkRecordTask extends RecursiveAction {

        private static final long serialVersionUID = 1L;
        private static final int LIMIT = 100;
        private DataTable dt;
        private int start;
        private int end;

        public LinkRecordTask(DataTable dt, int start, int end) {
            this.dt = dt;
            this.start = start;
            this.end = end;
        }

        @Override
        protected void compute() {
            if ((end - start) < LIMIT) {
                for (int i = start; i < end; i++) {
                    Record r = dt.records.get(i);
                    Record parent = dt.getRecordById(r.parentId);
                    r.parent = parent;
                    if (parent != null) {
                        parent.children.add(r);
                    }
                }
            } else {
                // Unsigned shift avoids int overflow of (start + end) on huge ranges.
                int mid = (start + end) >>> 1;
                invokeAll(new LinkRecordTask(dt, start, mid), new LinkRecordTask(dt, mid, end));
            }
        }
    }

    /**
     * For each record in [start, end), computes the names of all its descendants
     * and stores them in the shared map. Records without descendants are skipped.
     */
    private static class BuildChildrenMapTask extends RecursiveAction {

        private static final long serialVersionUID = 1L;
        private static final int LIMIT = 100;
        private DataTable dt;
        private int start;
        private int end;
        private Map<String, List<String>> map;

        public BuildChildrenMapTask(DataTable dt, int start, int end, Map<String, List<String>> map) {
            this.dt = dt;
            this.start = start;
            this.end = end;
            this.map = map;
        }

        @Override
        protected void compute() {
            if ((end - start) < LIMIT) {
                computeDirectly();
            } else {
                int mid = (start + end) >>> 1;
                invokeAll(new BuildChildrenMapTask(dt, start, mid, map),
                          new BuildChildrenMapTask(dt, mid, end, map));
            }
        }

        /** Sequentially processes the records in [start, end). */
        private void computeDirectly() {
            for (int i = start; i < end; i++) {
                Record rec = dt.records.get(i);
                List<String> names = collectDescendantNames(rec);
                if (!names.isEmpty()) {
                    map.put(rec.name, names);
                }
            }
        }

        /**
         * Returns the names of all records below {@code root}, each name at most once,
         * in depth-first pre-order.
         *
         * <p>Iterative (explicit stack of child iterators) so that deep hierarchies,
         * such as a long linear chain, cannot overflow the call stack. Name
         * de-duplication uses a hash set instead of {@code List.contains}, which made
         * the old version O(n^2) per record. An identity-based visited set
         * guarantees termination even if the parent links contain a cycle.
         */
        private static List<String> collectDescendantNames(Record root) {
            java.util.Set<String> names = new java.util.LinkedHashSet<>();
            java.util.Set<Record> visited =
                    java.util.Collections.newSetFromMap(new java.util.IdentityHashMap<Record, Boolean>());
            java.util.Deque<java.util.Iterator<Record>> stack = new java.util.ArrayDeque<>();

            visited.add(root);
            stack.push(root.children.iterator());
            while (!stack.isEmpty()) {
                java.util.Iterator<Record> it = stack.peek();
                if (!it.hasNext()) {
                    stack.pop();
                    continue;
                }
                Record child = it.next();
                if (visited.add(child)) {
                    names.add(child.name);
                    stack.push(child.children.iterator());
                }
            }
            return new ArrayList<>(names);
        }
    }

}

我的记录课:

/**
 * Represents a structure of a record in DataTable.
 */
public class Record {

    public Integer id;
    public String name;
    public Integer parentId;
    public Record parent;
    public Collection<Record> children;

    public Record(Integer id, String name, Integer parentId) {
        this();
        this.id = id;
        this.name = name;
        this.parentId = parentId;
    }

    public Record() {
       children = Collections.newSetFromMap(new ConcurrentHashMap<Record, Boolean>())
    }

    public Collection<Record> getChildren() {
       return children;
    }

    public Record getParent() {
       return parent;
    }

    public Integer getParentId() {
       return parentId;
    }

    @Override
    public String toString() {
        return "Record{" + "id=" + id + ", name=" + name + ", parentId=" + parentId + '}';
    }

    /* (non-Javadoc)
     * @see java.lang.Object#hashCode()
     */
    @Override
    public int hashCode() {
       final int prime = 31;
       int result = 1;
       result = prime * result + ((id == null) ? 0 : id.hashCode());
       result = prime * result + ((name == null) ? 0 : name.hashCode());
       result = prime * result  + ((parentId == null) ? 0 : parentId.hashCode());
       return result;
    }

    /* (non-Javadoc)
     * @see java.lang.Object#equals(java.lang.Object)
     */
    @Override
    public boolean equals(Object obj) {
    if (this == obj) {
        return true;
    }
    if (obj == null) {
        return false;
    }
    if (!(obj instanceof Record)) {
        return false;
    }
    Record other = (Record) obj;
    if (id == null) {
        if (other.id != null) {
        return false;
        }
    } else if (!id.equals(other.id)) {
        return false;
    }
    if (name == null) {
        if (other.name != null) {
        return false;
        }
    } else if (!name.equals(other.name)) {
        return false;
    }
    if (parentId == null) {
        if (other.parentId != null) {
        return false;
        }
    } else if (!parentId.equals(other.parentId)) {
        return false;
    }
    return true;
    }

}

我的算法是:

    - Link all parent and child of each record
    - Build the map 

In each step I use fork/join to divide the dataset into smaller parts and process them in parallel.

我不知道这个实现有什么问题。有人可以给我一些建议吗?对于一个包含5K条记录的线性层次结构(项目1是根,也是项目2的父项;项目2是项目3的父项;项目3是项目4的父项……依此类推),此实现会抛出OutOfMemoryError。我认为这是因为它多次调用了递归方法。

这个问题的优秀算法是什么?我应该修改哪个数据结构以使其更好?

1 个答案:

答案 0 :(得分:4)

你似乎陷入了一种诱惑:写了比实现目标所需多得多的代码。根据你的数据,我们可以编写一个简单的树结构,用它来进行祖先和后代搜索:

import java.util.HashMap;
import java.util.ArrayList;

class Node {
  // static lookup table, because we *could* try to find nodes by walking
  // the node tree, but the ids are uniquely identifying: this way we can
  // do an instant lookup. Efficiency!
  static HashMap<Long, Node> NodeLUT = new HashMap<Long, Node>();

  // we could use Node.NodeLUT.get(...), but having a Node.getNode(...) is nicer
  public static Node getNode(long id) {
    return Node.NodeLUT.get(id);
  }

  // we don't call the Node constructor directly, we just let this factory
  // take care of that for us instead.
  public static Node create(long _id, String _label) {
    return new Node(_id, _label);
  }

  /**
   * Create a node and attach it as a child of the existing node with id [_parent].
   *
   * @throws IllegalArgumentException if no node with id [_parent] exists yet; in that
   *     case the new node is NOT registered in the lookup table
   */
  public static Node create(long _id, String _label, long _parent) {
    Node parent = Node.NodeLUT.get(_parent);
    if (parent == null) {
      // Check before constructing: the constructor registers the node in the LUT,
      // and we don't want a half-created orphan left behind.
      throw new IllegalArgumentException(
          "No parent node with id " + _parent + " for node " + _id);
    }
    Node node = new Node(_id, _label);
    parent.addChild(node);
    return node;
  }

  // instance variables and methods

  Node parent;
  long id;
  String label;
  ArrayList<Node> children = new ArrayList<Node>();

  // again: no public constructor. We can only use Node.create if we want
  // to make Node objects.
  private Node(long _id, String _label) {
    parent = null;
    id = _id;
    label = _label;
    Node.NodeLUT.put(id, this);
  }

  // this is taken care of in Node.create, too
  private void addChild(Node child) {
    children.add(child);
    child.setParent(this);
  }

  // as is this.
  private void setParent(Node _parent) {
    parent = _parent;
  }

  /**
   * Find the route from this node, to some descendant node with id [descendentId].
   * Returns the path starting at this node and ending at the descendant, or null if
   * no such descendant exists. If [descendentId] is this node's own id, the path is
   * just this node.
   *
   * Because ids are unique, we look the target up in the LUT and walk its parent
   * pointers upward. That costs O(depth) time and needs no recursion, so it also
   * works on very deep (e.g. linear) trees.
   */
  public ArrayList<Node> getDescendentPathTo(long descendentId) {
    Node target = Node.NodeLUT.get(descendentId);
    if (target == null) {
      return null;
    }
    ArrayList<Node> path = new ArrayList<Node>();
    for (Node n = target; n != null; n = n.parent) {
      path.add(n);
      if (n == this) {
        // collected target -> this; callers expect this -> target
        java.util.Collections.reverse(path);
        return path;
      }
    }
    return null; // reached the root without passing through this node
  }

  /**
   * Find the route from this node, to some ancestral node with id [ancestorId].
   * Returns the path starting at this node and ending at the ancestor, or null if
   * no node on the way up to the root has that id. (The previous recursive version
   * threw a NullPointerException at the root instead of returning null.)
   */
  public ArrayList<Node> getAncestorPathTo(long ancestorId) {
    ArrayList<Node> path = new ArrayList<Node>();
    for (Node n = this; n != null; n = n.parent) {
      path.add(n);
      if (n.id == ancestorId) {
        return path;
      }
    }
    return null;
  }

  public String toString() {
    return "{id:"+id+",label:"+label+"}";
  }
}

因此,让我们通过添加标准public static void main(String[] args)方法来测试它以确保它有效,并且为了方便起见,还有一个函数将Node的ArrayLists转换成可读的东西:

  /**
   * Render a list as the toString() of each element joined with ", ".
   * Uses a StringBuilder: repeated String += in a loop re-copies the whole
   * accumulated string on every iteration (quadratic in the output length).
   *
   * @param list the elements to render; must not be null or contain null elements
   *             (each element's toString() is called)
   * @return the comma-separated rendering, or "" for an empty list
   */
  public static String stringify(ArrayList<?> list) {
    StringBuilder sb = new StringBuilder();
    for (int s = 0, l = list.size(); s < l; s++) {
      if (s > 0) {
        sb.append(", ");
      }
      sb.append(list.get(s).toString());
    }
    return sb.toString();
  }

  /**
   * Demo: builds the tree from the question's sample table, then prints the
   * path from the root A (id 0) down to F (id 6) and from F back up to A.
   */
  public static void main(String[] args) {
    // The root has no parent, so it uses the two-argument factory.
    Node.create(0, "A");

    // Remaining rows from the question, in the same order: id, label, parent id.
    long[] ids = {1, 2, 4, 5, 6, 3};
    String[] labels = {"B", "C", "D", "E", "F", "G"};
    long[] parentIds = {0, 0, 1, 1, 4, 0};
    for (int row = 0; row < ids.length; row++) {
      Node.create(ids[row], labels[row], parentIds[row]);
    }

    Node top = Node.getNode(0);
    Node leafF = Node.getNode(6);

    String downward = stringify(top.getDescendentPathTo(6));
    System.out.println("From root to F: " + downward);

    String upward = stringify(leafF.getAncestorPathTo(0));
    System.out.println("From F to root: " + upward);
  }

输出

From root to F: {id:0,label:A}, {id:1,label:B}, {id:4,label:D}, {id:6,label:F}
From F to root: {id:6,label:F}, {id:4,label:D}, {id:1,label:B}, {id:0,label:A}

完美。

所以,我们只需要编写一段代码,把你的“平面定义”转换为一系列Node.create调用,就大功告成了。请记住:不要把问题复杂化。如果你的数据是一棵以平面形式表示的树,那么你只需要一个树结构,而编写树结构只需要一个Node类。