Prim算法的实现

时间:2009-04-19 23:34:56

标签: java algorithm priority-queue prims-algorithm

对于我的CS类,我需要在Java中实现Prim的算法,并且我遇到了优先级队列步骤的问题。我有优先级队列的经验,并且理解他们的工作一般,但我遇到了一个特定的步骤。

Prim(G,w,r)
  For each u in V[G]
    do key[u] ← ∞ 
       π[u] ← NIL  
  key[r] ← 0
  Q ← V[G]  
  While Q ≠ Ø
    do u ← EXTRACT-MIN(Q)
       for each v in Adj[u]
            if v is in Q and w(u,v) < key[v]
                 then π[v] ← u
                       key[v] ← w(u,v)

我创建了一个Node类,其中包含键值(我假设是连接到Node的最轻边缘)和父节点。我的问题是我不明白将节点添加到优先级队列。将父节点设置为NIL且密钥为∞时,将所有节点添加到优先级队列对我来说没有意义。

3 个答案:

答案 0 :(得分:2)

在您的问题中的伪代码中,key[u]π[u]是表示算法完成时G的最小生成树的值集。在算法开始时,这些值分别初始化为NIL,表示尚未向MST添加任何顶点。下一步设置根元素(key[r] ← 0)。

优先级队列Q是来自keyπ的单独数据结构。 Q应使用原始图G中的所有顶点进行初始化,不是key和π中的值。请注意,您需要的信息不仅仅是每个顶点最轻的边和父节点,因为您需要知道从Q中提取的每个顶点旁边的所有顶点。

答案 1 :(得分:2)

如果您不想使用PriorityQueue,here是我在Java中实现的Heap ..您可以用MinHeap替换PriorityQueue。

package algo2;

import java.io.DataInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

public class Prims {

private static final class GNode implements Comparable<GNode> {
    // unique id of the node
    int id;

    // map of child nodes and their respective distance from current node 'id'
    Map<GNode, Integer> children = new HashMap<GNode, Integer>();

    // used in future computation, to store minimal/optimal updated distance
    int distFromParent=0;

    public GNode(int i) {
        this.id=i;
    }

    @Override
    public int compareTo(GNode o) {
        return this.distFromParent-o.distFromParent;
    }

    @Override
    public String toString() {
        return "GNode [id=" + id + ", distFromParent=" + distFromParent
                + "]";
    }
}

static long findLengthOfMST(GNode[] nodes) {
    PriorityQueue<GNode> pq = new PriorityQueue<GNode>();
    boolean[] visited = new boolean[nodes.length];
    boolean[] exited = new boolean[nodes.length];
    pq.add(nodes[1]);
    visited[1] = true;
    long sum = 0;
    int count = 0;
    while (pq.size() > 0) {
        GNode o = pq.poll();
        if (!exited[o.id]) {
            for (GNode n : o.children.keySet()) {
                if (exited[n.id]) {
                    continue;
                }
                if (visited[n.id]) {
                    if (n.distFromParent >= o.children.get(n)) {
                        n.distFromParent = o.children.get(n);
                    }
                } else {
                    visited[n.id] = true;
                    n.distFromParent = o.children.get(n);
                    pq.add(n);
                }
            }
            sum += o.distFromParent;
            exited[o.id] = true;
            count++;
        }
        if (pq.size() == 0) {
            for (int i = 1; i < nodes.length; i++) {
                if (!exited[i]) {
                    pq.add(nodes[i]);
                }
            }
        }
    }
    System.out.println(count);
    return sum;
}

public static void main(String[] args) {
    StdIn s = new StdIn(System.in);
    int V = s.nextInt();
    int E = s.nextInt();
    GNode[] nodes = new GNode[V+1];
    for (int i = 0; i < E; i++) {
        int u = s.nextInt();
        int v = s.nextInt();
        GNode un = nodes[u];
        GNode vn = nodes[v];
        if (un == null) {
            un = new GNode(u);
            nodes[u] = un;
        }
        if (vn == null) {
            vn = new GNode(v);
            nodes[v] = vn;
        }

        int w = s.nextInt();
        un.children.put(vn, w);
        vn.children.put(un, w);
    }
    long len = findLengthOfMST(nodes);
    System.out.println(len);
}

private static class StdIn {
    final private int BUFFER_SIZE = 1 << 17;
    private DataInputStream din;
    private byte[] buffer;
    private int bufferPointer, bytesRead;
    public StdIn(InputStream in) {
    din = new DataInputStream(in);
    buffer = new byte[BUFFER_SIZE];
    bufferPointer = bytesRead = 0;
    }
    public int nextInt() {int ret = 0;byte c = read();while (c <= ' ')c = read();boolean neg = c == '-';if (neg)c=read();do{ret=ret*10+c-'0';c = read();} while (c>' ');if(neg)return -ret;return ret;}
    private void fillBuffer(){try{bytesRead=din.read(buffer,bufferPointer=0,BUFFER_SIZE);}catch(Exception e) {}if(bytesRead==-1)buffer[0]=-1;}
    private byte read(){if(bufferPointer == bytesRead)fillBuffer();return buffer[bufferPointer++];}
    }
}

答案 2 :(得分:1)

您无需担心将所有节点添加到优先级队列,即使它们具有无限密钥;它们最终将在您的伪代码的最后一行被DECREASE_KEY降低。无论如何你都需要这个操作,所以没有理由不简化你的生活并一次性插入它们。

我看到你的伪代码只有一个问题,就是它在断开连接的图上会表现得很奇怪。