优先级队列插入键值对java

时间:2012-12-03 18:45:18

标签: java priority-queue

背景

我试图在O(mlogn)时间内编码Dijkstra算法,其中m是边数,n是节点数。我用来找到给定起始节点和给定结束节点之间的最短路径。而且我对此很陌生。

以下是我提出的算法:

假设图由邻接矩阵表示,每个节点都有一个行索引。

Initialize starting node distance to zero, and all other nodes to inifinity, in the heap.

Create a list of shortest paths, equal to the number of nodes in the graph, set to 0.

While the index of the node that corresponds to the minimum element in the heap 
has no value in the list of shortest paths and heap has node distances, do:
    Remove the minimum node distance from the heap, and bubble as necessary to fill the removed node.
    Put the minimum node distance into the list of shortest paths at its row index.

    For all nodes that were adjacent to the node with the minimum distance (that was just removed), do:
      Update the distances in the heap for the current node, using the following calculation:
        min((deleted node distance + adjacent edge weight), current node's distance)
    Reorganize the heap to be a minimum heap.

Return value in the list of shortest paths at the location of the end node.

这是O(mlogn)因为你每条边只更新一次距离。

  

“需要线性时间   初始化堆,然后我们以O(log n)的代价执行m次更新,总时间为O(mlog n)。“ - http://www.cs.cmu.edu/~avrim/451f07/lectures/lect1011.pdf

问题

为了更新堆中正确位置的起始顶点的距离,对堆的插入必须是键值对 - 键是节点(行索引),值是距离。 / p>

有演讲幻灯片online表示优先级队列中的每个条目ADT都是一个键值对(否则,它如何优先排序?)。

问题

PriorityQueue的方法最多只有一个参数,那么如何插入与值关联的键?

这必须在具有特定名称的单个文件中完成(即,我的理解是我不能使KeyValuePair类实现Comparator。)

我很想听听你的想法。

4 个答案:

答案 0 :(得分:3)

要为您的应用程序使用JDK的优先级队列实现,除了Map<Key, Value>之外,您还可以维护PriorityQueue<Value>。在您的情况下,Key表示节点,Value是保存节点最短距离的对象。要更新到节点的距离,首先要在地图中查找其对应的距离对象。然后,从优先级队列中删除距离对象。接下来,更新距离对象。最后,将距离对象插回优先级队列。

答案 1 :(得分:0)

我很欣赏我的问题的答案,当时我选择了地图答案,因为由于对语言的理解有限,我似乎更容易实现。

事实证明,我忽略了一个重要的细节,使问题比我想象的要简单得多:如果我维护一个距离数组并将节点插入堆(而不是距离),用作参考在距离数组中,我能够根据它们的值对节点进行排序。

在这个实现中,我毕竟不需要设计键值属性。在更新距离数组中的值之后,我必须删除并重新将这些特定节点添加到堆中,以便堆保持最新并按照@reprogrammer的建议进行排序。

一旦我改变了我放入堆中的内容,该算法与the one found on Wikipedia非常相似。

以下是我最终使用的代码,以防有人遇到同样的问题。注意:神奇的部分是PriorityQueue的创建(类似于@stevevls建议的):

import java.util.*;
import java.io.File; //Because files were used to test correctness.
import java.lang.Math;

public class Dijkstra{

//This value represents infinity.
public static final int MAX_VAL = (int) Math.pow(2,30);

/* Assumptions:
    If G[i][j] == 0, there is no edge between vertex i and vertex j
    If G[i][j] > 1, there is an edge between i and j and the value of G[i][j] is its weight.
    No entry of G will be negative.
*/


static int dijkstra(int[][] G, int i, int j){
    //Get the number of vertices in G
    int n = G.length;

    // The 'i' parameter indicates the starting node and the 'j' parameter
    // is the ending node.


    //Create a list of size n of shortest paths, initialize each entry to infinity
    final int[] shortestPaths = new int[n];

    for(int k = 0; k < n; k++){
        shortestPaths[k] = MAX_VAL;
    }

    //Initialize starting node distance to zero.
    shortestPaths[i] = 0;

    //Make a Priority Queue (a heap)
    PriorityQueue<Integer> PQ = new PriorityQueue<Integer>(n,
        new Comparator<Integer>()
            {
                public int compare(Integer p, Integer q)
                {
                    return shortestPaths[p] - shortestPaths[q];
                }
            } );

    //Populate the heap with the nodes of the graph
    for(int k = 0; k < n; k++){
        PQ.offer(k);
    }

    //While the heap has elements.
    while(PQ.size() > 0){

    //  Remove the minimum node distance from the heap.
        int minimum = PQ.poll();

    //  Check if graph is disconnected, if so, return -1.
        if(shortestPaths[minimum] == MAX_VAL)
            {
                return -1;
            }
    //  End node has been reached (i.e. you've found the shortest path), return the distance.
        if( minimum == j){
            return shortestPaths[j];
        }

    //  Take the current node and look through the row to see the vertices adjacent to it (neighbours)
        for(int columnIt = 0; columnIt < n; columnIt ++){


    //    Update the distances in the heap for the current node, using the following calculation:
    //      min((deleted node distance + adjacent edge weight), current node's distance)

            if(G[minimum][columnIt] > 0){

                int sum = shortestPaths[minimum] + G[minimum][columnIt];

                shortestPaths[columnIt]= Math.min(sum, shortestPaths[columnIt]);

                if(shortestPaths[columnIt]==sum)
                {
                    PQ.remove(columnIt);
                    PQ.offer(columnIt);
                }
            }
        }
    }
    return -1;
}

感谢您的回答和建议。

答案 2 :(得分:0)


我正在解决同样的问题。我知道在哪里可以找到你的答案。 这是一本很棒的书,有代码示例 - Algorithms,第4版,Robert Sedgewick和Kevin Wayne。


Site bookExample of code(包括使用 PriorityQueue Dijkstra 算法的实现) Dijkstra 算法的这种实现并没有使用Java的标准 PriorityQueue 实现。相反,它实现了IndexMinPQ,本书前面已经讨论了详细解释!

答案 3 :(得分:0)

以下是使用priority_queue的Dijkstra实现。 这里忽略了InputReader类,因为它用于快速输入。我们可以根据键值对中的“值”保持优先级。然后选择具有最低成本的对,即值。

import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.InputMismatchException;
import java.util.List;
import java.util.PriorityQueue;
/** 
 * By: Rajan Parmar
 * At : HackerRank 
 **/
public class Dijkstra { 
     // node ,pair ( neighbor , cost)
    static HashMap < Integer , HashSet <Pair>> node;
    static PrintWriter w;   
    public static void main(String [] s) throws Exception{
        InputReader in;

        boolean online = false;
        String fileName = "input";

        node = new HashMap<Integer, HashSet<Pair>>();


        //ignore online if false it is for online competition
        if (online) {

            //ignore
            in = new InputReader(new FileInputStream(
                    new File(fileName + ".txt")));
            w = new PrintWriter(new FileWriter(fileName + "Output.txt"));
        } else {

            // for fast input output . You can use any input method
            in = new InputReader(System.in);
            w = new PrintWriter(System.out);
        }

        // Actual code starts here
        int t;
        int n, m;
        t = in.nextInt();

        while(t-- > 0){
            n = in.nextInt();
            m = in.nextInt();
            while(m-- > 0){
                int x,y,cost;
                x = in.nextInt();
                y = in.nextInt();
                cost = in.nextInt();

                if(node.get(x)==null){
                    node.put(x, new HashSet());
                    node.get(x).add(new Pair(y,cost));
                }
                else{
                    node.get(x).add(new Pair(y,cost));
                }
                if(node.get(y)==null){
                    node.put(y, new HashSet());
                    node.get(y).add(new Pair(x,cost));
                }
                else{
                    node.get(y).add(new Pair(x,cost));
                }
            }
            int source = in.nextInt();
            Dijkstra(source,n);
            node.clear();
            System.out.println("");
        }
    }

    static void Dijkstra(int start , int n) {

        int dist[] = new int[3001];
        int visited[] = new int[3001];
        Arrays.fill(dist, Integer.MAX_VALUE);
        Arrays.fill(visited, 0);
        dist[start] = 0 ;
        PriorityQueue < Pair > pq = new PriorityQueue();

        //this will be prioritized according to VALUES (i.e cost in class Pair)
        pq.add(new Pair(start , 0));
        while(!pq.isEmpty()){
            Pair pr = pq.remove();
            visited[pr.neighbor] = 1;
            for(Pair p:node.get(pr.neighbor)){
                if(dist[p.neighbor] > dist[pr.neighbor] + p.cost){
                    dist[p.neighbor] = dist[pr.neighbor] + p.cost;

                    //add updates cost to vertex through start vertex
                    if(visited[p.neighbor]==0)
                        pq.add(new Pair(p.neighbor ,dist[p.neighbor] ));
                }

            }
        }
        for(int i=1;i<=n;i++){
            if(i==start) continue;
            if(visited[i]==0)
                dist[i]=-1;
            System.out.print(dist[i]+" ");
        }
    }

    static class Pair implements Comparable {

        int neighbor;
        int cost;

        public Pair(int y, int cost) {
            // TODO Auto-generated constructor stub
            neighbor = y;
            this.cost = cost;
        }

        @Override
        public int compareTo(Object o) {
            // TODO Auto-generated method stub
            Pair pr = (Pair)o;

            if(cost > pr.cost)
                return 1;
            else
                return -1;

        }

    }

    //Ignore this class , it is for fast input.
    static class InputReader {

        private InputStream stream;
        private byte[] buf = new byte[8192];
        private int curChar, snumChars;
        private SpaceCharFilter filter;

        public InputReader(InputStream stream) {
            this.stream = stream;
        }

        public int snext() {
            if (snumChars == -1)
                throw new InputMismatchException();
            if (curChar >= snumChars) {
                curChar = 0;
                try {
                    snumChars = stream.read(buf);
                } catch (IOException e) {
                    throw new InputMismatchException();
                }
                if (snumChars <= 0)
                    return -1;
            }
            return buf[curChar++];
        }

        public int nextInt() {
            int c = snext();
            while (isSpaceChar(c))
                c = snext();
            int sgn = 1;
            if (c == '-') {
                sgn = -1;
                c = snext();
            }
            int res = 0;
            do {
                if (c < '0' || c > '9')
                    throw new InputMismatchException();
                res *= 10;
                res += c - '0';
                c = snext();
            } while (!isSpaceChar(c));
            return res * sgn;
        }

        public long nextLong() {
            int c = snext();
            while (isSpaceChar(c))
                c = snext();
            int sgn = 1;
            if (c == '-') {
                sgn = -1;
                c = snext();
            }
            long res = 0;
            do {
                if (c < '0' || c > '9')
                    throw new InputMismatchException();
                res *= 10;
                res += c - '0';
                c = snext();
            } while (!isSpaceChar(c));
            return res * sgn;
        }

        public int[] nextIntArray(int n) {
            int a[] = new int[n];
            for (int i = 0; i < n; i++)
                a[i] = nextInt();
            return a;
        }

        public String readString() {
            int c = snext();
            while (isSpaceChar(c))
                c = snext();
            StringBuilder res = new StringBuilder();
            do {
                res.appendCodePoint(c);
                c = snext();
            } while (!isSpaceChar(c));
            return res.toString();
        }

        public boolean isSpaceChar(int c) {
            if (filter != null)
                return filter.isSpaceChar(c);
            return c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == -1;
        }

        public interface SpaceCharFilter {
            public boolean isSpaceChar(int ch);
        }
    }
}

这将采用以下格式输入。

第一行是T(测试案例编号)。

对于每个测试用例,下一行输入将是N和M,其中N不是节点,M不是边缘。

下一个M行包含3个整数,即x,y,W。它表示节点x和y之间的边缘,权重为W。

下一行包含单个整数,即Source节点。

输出:

从给定的源节点打印到所有节点的最短距离。如果节点无法访问,则打印-1。

e.g

输入:

1
6 8
1 2 1
1 5 4
2 5 2
2 3 2
5 6 5
3 6 2
3 4 1
6 4 3
1

输出:(节点1的所有节点的最短距离)

1 3 4 3 5