我使用Java实现的Held-Karp TSP algorithm算法来解决25个城市的TSP问题。 该计划通过4个城市。
当它在25个城市运行时,它会停止几个小时。我现在使用jVisualVM来查看热点是什么,经过一些优化 98%的时间用于实际计算,而不是Map.contains或Map.get。
所以我想得到你的建议,这是代码:
private void solve() throws Exception {
long beginTime = System.currentTimeMillis();
int counter = 0;
List<BitSetEndPointID> previousCosts;
List<BitSetEndPointID> currentCosts;
//maximum number of elements is c(n,[n/2])
//To calculate m-set's costs just need to keep (m-1)set's costs
List<BitSetEndPointID> lastKeys = new ArrayList<BitSetEndPointID>();
int m;
if (totalNodes < 10) {
//for test data, generate them on the fly
SetUtil3.generateMSet(totalNodes);
}
//m=1
BitSet beginSet = new BitSet();
beginSet.set(0);
previousCosts = new ArrayList<BitSetEndPointID>(1);
BitSetEndPointID beginner = new BitSetEndPointID(beginSet, 0);
beginner.setCost(0f);
previousCosts.add(beginner);
//for m=2 to totalNodes
for (m = 2; m <= totalNodes; m++) {// sum(m=2..n 's C(n,m)*(m-1)(m-1)) ==> O(n^2 * 2^n)
//pick m elements from total nodes, the element id is the index of nodeCoordinates
// the first node is always present
BitSet[] msets;
if (totalNodes < 10) {
msets = SetUtil3.msets[m - 1];
} else {
//for real data set, will read from serialized file
msets = SetUtil3.getMsets(totalNodes, m-1);
}
currentCosts = new ArrayList<BitSetEndPointID>(msets.length);
//System.out.println(m + " sets' size: " + msets.size());
for (BitSet mset : msets) { //C(n,m) mset
int[] candidates = allSetBits(mset, m);
//mset is a BitSet which makes sure begin point 0 comes first
//so end point candidate begins with 1. candidate[0] is always begin point 0
for (int i = 1; i < candidates.length; i++) { // m-1 bits are set
//set the new last point as j, j must not be the same as begin point 0
int j = candidates[i];
//middleNodes = mset -{j}
BitSet middleNodes = (BitSet) mset.clone();
middleNodes.clear(j);
//loop through all possible points which are second to the last
//and get min(A[S-{j},k] + k->j), k!=j
float min = Float.MAX_VALUE;
int k;
for (int ki = 0; ki < candidates.length; ki++) {// m-1 calculation
k = candidates[ki];
if (k == j) continue;
float middleCost = 0;
BitSetEndPointID key = new BitSetEndPointID(middleNodes, k);
int index = previousCosts.indexOf(key);
if (index != -1) {
//System.out.println("get value from map in m " + m + " y key " + middleNodes);
middleCost = previousCosts.get(index).getCost();
} else if (k == 0 && !middleNodes.equals(beginSet)) {
continue;
} else {
System.out.println("middleCost not found!");
continue;
// System.exit(-1);
}
float lastCost = distances[k][j];
float cost = middleCost + lastCost;
if (cost < min) {
min = cost;
}
counter++;
if (counter % 500000 == 0) {
try {
Thread.currentThread().sleep(100);
} catch (InterruptedException iex) {
System.out.println("Who dares interrupt my precious sleep?!");
}
}
}
//set the costs for chosen mset and last point j
BitSetEndPointID key = new BitSetEndPointID(mset, j);
key.setCost(min);
currentCosts.add(key);
// System.out.println("===========================================>mset " + mset + " and end at " +
// j + " 's min cost: " + min);
// if (m == totalNodes) {
// lastKeys.add(key);
// }
}
}
previousCosts = currentCosts;
System.out.println("...");
}
calcLastStop(lastKeys, previousCosts);
System.out.println(" cost " + (System.currentTimeMillis() - beginTime) / 60000 + " minutes.");
}
private void calcLastStop(List<BitSetEndPointID> lastKeys, List<BitSetEndPointID> costs) {
//last step, calculate the min(A[S={1..n},k] +k->1)
float finalMinimum = Float.MAX_VALUE;
for (BitSetEndPointID key : costs) {
float middleCost = key.getCost();
Integer endPoint = key.lastPointID;
float lastCost = distances[endPoint][0];
float cost = middleCost + lastCost;
if (cost < finalMinimum) {
finalMinimum = cost;
}
}
System.out.println("final result: " + finalMinimum);
}
答案 0 :(得分:3)
您可以通过使用基元数组(它可能需要比对象列表更好的内存布局)和直接在位掩码上操作(没有位集或其他对象)来加速代码。这是一些代码(它生成一个随机图,但你可以很容易地改变它,以便它读取你的图形):
import java.io.*;
import java.util.*;
class Main {
final static float INF = 1e10f;
public static void main(String[] args) {
final int n = 25;
float[][] dist = new float[n][n];
Random random = new Random();
for (int i = 0; i < n; i++)
for (int j = i + 1; j < n; j++)
dist[i][j] = dist[j][i] = random.nextFloat();
float[][] dp = new float[n][1 << n];
for (int i = 0; i < dp.length; i++)
Arrays.fill(dp[i], INF);
dp[0][1] = 0.0f;
for (int mask = 1; mask < (1 << n); mask++) {
for (int lastNode = 0; lastNode < n; lastNode++) {
if ((mask & (1 << lastNode)) == 0)
continue;
for (int nextNode = 0; nextNode < n; nextNode++) {
if ((mask & (1 << nextNode)) != 0)
continue;
dp[nextNode][mask | (1 << nextNode)] = Math.min(
dp[nextNode][mask | (1 << nextNode)],
dp[lastNode][mask] + dist[lastNode][nextNode]);
}
}
}
double res = INF;
for (int lastNode = 0; lastNode < n; lastNode++)
res = Math.min(res, dist[lastNode][0] + dp[lastNode][(1 << n) - 1]);
System.out.println(res);
}
}
在我的电脑上完成只需几分钟:
time java Main
...
real 2m5.546s
user 2m2.264s
sys 0m1.572s