如何优化我的 Held-Karp 算法 Java 实现以缩短运行时间？

时间:2016-10-28 12:32:25

标签: java algorithm optimization

我用Java实现了Held-Karp TSP算法，用来求解25个城市的TSP问题。该程序在4个城市的数据上可以正常运行。

但在25个城市上运行时，它会运行好几个小时仍未结束。我用jVisualVM查看了热点，经过一些优化后，98%的时间都花在实际计算上，而不是Map.contains或Map.get上。

所以想听听大家的建议，代码如下：

    /**
     * Runs the Held-Karp dynamic program and prints the optimal tour cost (through
     * {@code calcLastStop}) followed by the elapsed time in whole minutes.
     *
     * <p>Level {@code m} holds, for every m-element subset S that contains node 0 and every end
     * point j in S - {0}, the cheapest cost A[S, j] of a path that starts at 0, visits exactly S
     * and ends at j. Building level m only needs level m-1.
     *
     * <p>Previous-level costs are looked up through a hash map keyed by the subset's BitSet, so
     * each lookup is O(1) on average instead of a linear {@code List.indexOf} scan over the
     * whole previous level. There is no periodic sleep throttle.
     *
     * <p>Assumes {@code SetUtil3} returns, for index m-1, every m-element subset containing node 0
     * and that those BitSets are not mutated afterwards (they are used as map keys) — confirm.
     *
     * @throws Exception as declared; presumably from {@code SetUtil3} when loading the serialized
     *     subsets — confirm
     */
    private void solve() throws Exception {
        long beginTime = System.currentTimeMillis();

        // Passed through to calcLastStop unchanged.
        List<BitSetEndPointID> lastKeys = new ArrayList<BitSetEndPointID>();
        if (totalNodes < 10) {
            //for test data, generate them on the fly
            SetUtil3.generateMSet(totalNodes);
        }

        // m = 1: the only subset is {0}, reached at cost 0 and ending at 0.
        BitSet beginSet = new BitSet();
        beginSet.set(0);
        BitSetEndPointID beginner = new BitSetEndPointID(beginSet, 0);
        beginner.setCost(0f);
        // Entries of the most recent level that calcLastStop needs (only the final one is built).
        List<BitSetEndPointID> previousCosts = new ArrayList<BitSetEndPointID>(1);
        previousCosts.add(beginner);

        // previousRows.get(S)[k] == A[S, k]; Float.NaN marks "no entry for end point k".
        java.util.Map<BitSet, float[]> previousRows = new java.util.HashMap<BitSet, float[]>();
        float[] beginRow = new float[totalNodes];
        java.util.Arrays.fill(beginRow, Float.NaN);
        beginRow[0] = 0f;
        previousRows.put(beginSet, beginRow);

        for (int m = 2; m <= totalNodes; m++) { // sum(m=2..n 's C(n,m)*(m-1)(m-1)) ==> O(n^2 * 2^n)
            BitSet[] msets;
            if (totalNodes < 10) {
                msets = SetUtil3.msets[m - 1];
            } else {
                //for real data set, will read from serialized file
                msets = SetUtil3.getMsets(totalNodes, m - 1);
            }
            boolean lastLevel = (m == totalNodes);
            java.util.Map<BitSet, float[]> currentRows =
                    new java.util.HashMap<BitSet, float[]>(msets.length * 4 / 3 + 1);
            // The object list is only consumed by calcLastStop, so build it for the last level only.
            List<BitSetEndPointID> currentCosts =
                    lastLevel ? new ArrayList<BitSetEndPointID>(msets.length) : null;

            for (BitSet mset : msets) { //C(n,m) mset
                int[] candidates = allSetBits(mset, m);
                float[] row = new float[totalNodes];
                java.util.Arrays.fill(row, Float.NaN);
                // candidates[0] is always the begin point 0, so end points start at index 1.
                for (int i = 1; i < candidates.length; i++) {
                    int j = candidates[i];
                    //middleNodes = mset - {j}
                    BitSet middleNodes = (BitSet) mset.clone();
                    middleNodes.clear(j);
                    float[] middleRow = previousRows.get(middleNodes);

                    // min over k in S - {j} of A[S - {j}, k] + distance(k -> j)
                    float min = Float.MAX_VALUE;
                    for (int ki = 0; ki < candidates.length; ki++) {
                        int k = candidates[ki];
                        if (k == j) continue;
                        float middleCost = (middleRow == null) ? Float.NaN : middleRow[k];
                        if (Float.isNaN(middleCost)) {
                            // k == 0 is only a valid end point of the one-element set {0};
                            // any other miss means the subset tables are incomplete.
                            if (k != 0 || middleNodes.equals(beginSet)) {
                                System.out.println("middleCost not found!");
                            }
                            continue;
                        }
                        float cost = middleCost + distances[k][j];
                        if (cost < min) {
                            min = cost;
                        }
                    }

                    // min is never NaN (NaN costs fail the '<' test), so NaN stays "absent".
                    row[j] = min;
                    if (lastLevel) {
                        BitSetEndPointID key = new BitSetEndPointID(mset, j);
                        key.setCost(min);
                        currentCosts.add(key);
                    }
                }
                currentRows.put(mset, row);
            }

            previousRows = currentRows;
            if (lastLevel) {
                previousCosts = currentCosts;
            }
            System.out.println("...");
        }

        calcLastStop(lastKeys, previousCosts);
        System.out.println(" cost " + (System.currentTimeMillis() - beginTime) / 60000 + " minutes.");
    }


    /**
     * Closes every candidate path back to node 0 and prints the cheapest resulting tour as
     * {@code "final result: <cost>"}. Prints {@code Float.MAX_VALUE} when {@code costs} is empty.
     *
     * @param lastKeys not read by this method
     * @param costs    path entries; each contributes {@code getCost()} plus the edge from its
     *                 {@code lastPointID} back to node 0
     */
    private void calcLastStop(List<BitSetEndPointID> lastKeys, List<BitSetEndPointID>  costs) {
        //last step, calculate the min(A[S={1..n},k] +k->1)
        float best = Float.MAX_VALUE;
        for (BitSetEndPointID entry : costs) {
            float closedTour = entry.getCost() + distances[entry.lastPointID][0];
            if (closedTour < best) {
                best = closedTour;
            }
        }
        System.out.println("final result: " + best);
    }

1 个答案:

答案 0（得分：3）

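有两种方法可以加速代码。第一，使用基本类型（primitive）数组，它的内存布局通常比对象列表紧凑得多。第二，直接在 int 位掩码上操作，而不是使用 BitSet 或其他对象。此外，问题代码在最内层循环中对 List 调用 indexOf 来查找上一层的结果，这是线性扫描，代价极高；把位掩码直接用作数组下标，每次查找就变成了 O(1)。下面是示例代码（它生成一个随机图，但您可以很容易地修改它，让它读取您自己的图）：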

import java.io.*;
import java.util.*;

/**
 * Exact TSP solver (Held-Karp dynamic programming) over primitive arrays and int bitmasks.
 *
 * <p>{@code main} builds a random symmetric 25-node instance and prints the length of the
 * shortest closed tour; {@link #solve(float[][])} can be called with any distance matrix.
 */
class Main {

    /** "Unreachable" marker; assumes every real path cost stays far below it. */
    static final float INF = 1e10f;

    public static void main(String[] args) {
        final int n = 25;
        float[][] dist = new float[n][n];
        Random random = new Random();
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                dist[i][j] = dist[j][i] = random.nextFloat();
            }
        }
        System.out.println(solve(dist));
    }

    /**
     * Returns the length of the shortest closed tour that starts and ends at node 0 and visits
     * every node exactly once.
     *
     * <p>Node 0 belongs to every reachable state, so it is left out of the bitmask: bit {@code b}
     * of a subset mask stands for node {@code b + 1}. The table therefore holds n * 2^(n-1)
     * floats, half of what full n-bit masks need, and masks without node 0 are never visited.
     *
     * @param dist square matrix; {@code dist[a][b]} is the cost of the edge a -> b
     * @return the optimal tour length, or {@code INF} if no finite tour is found
     * @throws IllegalArgumentException if {@code dist} has fewer than 1 or more than 31 rows
     */
    static double solve(float[][] dist) {
        final int n = dist.length;
        if (n == 0 || n > 31) {
            // 1 << (n - 1) must fit in an int and there must be a start node.
            throw new IllegalArgumentException("node count must be in [1, 31], got " + n);
        }
        final int subsets = 1 << (n - 1);
        final int full = subsets - 1;

        // dp[v][s]: cheapest path 0 -> ... -> v visiting exactly {0} plus the nodes encoded in s.
        float[][] dp = new float[n][subsets];
        for (float[] row : dp) {
            Arrays.fill(row, INF);
        }
        dp[0][0] = 0.0f;

        // A transition only adds bits, so ascending mask order finalizes dp[.][s] before use.
        for (int s = 0; s < subsets; s++) {
            for (int last = 0; last < n; last++) {
                final float cur = dp[last][s];
                if (cur >= INF) {
                    continue; // unreachable state, including every 'last' not contained in s
                }
                final float[] fromLast = dist[last];
                for (int next = 1; next < n; next++) {
                    final int bit = 1 << (next - 1);
                    if ((s & bit) != 0) {
                        continue;
                    }
                    final float candidate = cur + fromLast[next];
                    if (candidate < dp[next][s | bit]) {
                        dp[next][s | bit] = candidate;
                    }
                }
            }
        }

        double res = INF;
        for (int last = 0; last < n; last++) {
            if (dp[last][full] < INF) {
                res = Math.min(res, dist[last][0] + dp[last][full]);
            }
        }
        return res;
    }
}

在我的电脑上，整个运行只需约两分钟：

time java Main
...
real    2m5.546s
user    2m2.264s
sys     0m1.572s