user1532146 user1532146 - 1 month ago 8
Java Question

How can I optimize my java implementation of Held-Karp algorithm to shorten the running time?

I am using a Java implementation of the Held-Karp algorithm to solve a 25-city TSP instance.
The program works correctly with 4 cities.

With 25 cities it runs for several hours without finishing. I used jVisualVM to find the hotspot; after some optimization it now shows that
98% of the time is spent in the actual computation rather than in Map.contains or Map.get.

I would appreciate your advice. Here is the code:

/**
 * Runs the Held-Karp dynamic program over node subsets that always contain start node 0,
 * then reports the optimal tour cost through {@code calcLastStop}.
 *
 * <p>For each subset size m = 2..totalNodes, the cost of every (subset, end point) pair is
 * derived from the costs of the (m-1)-subsets, so only the previous layer is kept.
 *
 * <p>Previous-layer costs are held in a {@code Map<BitSet, float[]>} (subset -> cost indexed by
 * end point, NaN = no cost recorded). The earlier {@code List.indexOf} lookup was a linear scan
 * over up to C(n, n/2) entries for every (subset, j, k) triple, which is what made 25 cities
 * effectively never finish. {@link BitSet} has value-based {@code equals}/{@code hashCode}, so
 * it is safe to use as a hash key as long as the subsets are not mutated afterwards.
 *
 * @throws Exception as declared; presumably from loading serialized m-sets (TODO confirm)
 */
private void solve() throws Exception {
    long beginTime = System.currentTimeMillis();
    int counter = 0;
    // Cleared once the thread is interrupted: the interrupt is preserved and we stop sleeping.
    boolean throttle = true;

    List<BitSetEndPointID> previousCosts;
    List<BitSetEndPointID> currentCosts;
    java.util.Map<BitSet, float[]> previousCostMap;
    java.util.Map<BitSet, float[]> currentCostMap;
    // maximum number of elements is c(n,[n/2])
    // To calculate m-set's costs just need to keep (m-1)set's costs
    // Never filled; kept only because calcLastStop takes it as a parameter.
    List<BitSetEndPointID> lastKeys = new ArrayList<BitSetEndPointID>();
    // Assumes node IDs are in [0, totalNodes); lookups are bounds-checked regardless.
    final int rowSize = Math.max(1, totalNodes);
    int m;
    if (totalNodes < 10) {
        // for test data, generate them on the fly
        SetUtil3.generateMSet(totalNodes);
    }

    // m=1: the path that has only visited start node 0 costs nothing.
    BitSet beginSet = new BitSet();
    beginSet.set(0);
    previousCosts = new ArrayList<BitSetEndPointID>(1);
    BitSetEndPointID beginner = new BitSetEndPointID(beginSet, 0);
    beginner.setCost(0f);
    previousCosts.add(beginner);
    previousCostMap = new java.util.HashMap<BitSet, float[]>();
    float[] beginRow = newCostRow(rowSize);
    beginRow[0] = 0f;
    previousCostMap.put(beginSet, beginRow);

    // for m=2 to totalNodes: sum(m=2..n of C(n,m)*(m-1)*(m-1)) ==> O(n^2 * 2^n)
    for (m = 2; m <= totalNodes; m++) {
        // pick m elements from total nodes; the element id is the index of nodeCoordinates.
        // The first node is always present.
        BitSet[] msets;
        if (totalNodes < 10) {
            msets = SetUtil3.msets[m - 1];
        } else {
            // for real data set, will read from serialized file
            msets = SetUtil3.getMsets(totalNodes, m - 1);
        }

        // Only the final layer's entry objects are read (by calcLastStop); intermediate
        // layers live solely in the map, which avoids allocating millions of key objects.
        boolean lastLayer = (m == totalNodes);
        currentCosts = new ArrayList<BitSetEndPointID>(lastLayer ? msets.length : 0);
        currentCostMap = new java.util.HashMap<BitSet, float[]>(msets.length * 4 / 3 + 1);

        for (BitSet mset : msets) { // C(n,m) msets
            int[] candidates = allSetBits(mset, m);
            float[] currentRow = newCostRow(rowSize);
            currentCostMap.put(mset, currentRow);

            // candidates[0] is always begin point 0, so end point candidates start at index 1.
            for (int i = 1; i < candidates.length; i++) {
                // set the new last point as j; j is never the begin point 0
                int j = candidates[i];
                // middleNodes = mset - {j}
                BitSet middleNodes = (BitSet) mset.clone();
                middleNodes.clear(j);
                // One hash lookup per (mset, j) instead of one linear scan per (mset, j, k).
                float[] middleRow = previousCostMap.get(middleNodes);
                boolean middleIsBeginSet = middleNodes.equals(beginSet);

                // min over k != j of A[S-{j}, k] + k->j
                float min = Float.MAX_VALUE;
                for (int ki = 0; ki < candidates.length; ki++) { // m-1 calculations
                    int k = candidates[ki];
                    if (k == j) {
                        continue;
                    }
                    float middleCost;
                    if (middleRow != null && k < middleRow.length && !Float.isNaN(middleRow[k])) {
                        middleCost = middleRow[k];
                    } else if (k == 0 && !middleIsBeginSet) {
                        // Paths ending at start node 0 are only stored for the {0} set itself.
                        continue;
                    } else {
                        System.out.println("middleCost not found!");
                        continue;
                    }

                    float lastCost = distances[k][j];
                    float cost = middleCost + lastCost;
                    if (cost < min) {
                        min = cost;
                    }

                    counter++;
                    if (throttle && counter % 500000 == 0) {
                        try {
                            Thread.sleep(100);
                        } catch (InterruptedException iex) {
                            System.out.println("Who dares interrupt my precious sleep?!");
                            // Preserve the interrupt for callers and stop throttling, so we
                            // don't hit an immediate InterruptedException every 500000 steps.
                            Thread.currentThread().interrupt();
                            throttle = false;
                        }
                    }
                }

                // record the cost for the chosen mset ending at point j
                if (j < currentRow.length) {
                    currentRow[j] = min;
                }
                if (lastLayer) {
                    BitSetEndPointID key = new BitSetEndPointID(mset, j);
                    key.setCost(min);
                    currentCosts.add(key);
                }
            }
        }
        previousCosts = currentCosts;
        previousCostMap = currentCostMap;
        System.out.println("...");
    }

    calcLastStop(lastKeys, previousCosts);
    System.out.println(" cost " + (System.currentTimeMillis() - beginTime) / 60000 + " minutes.");
}

/** Returns a fresh per-end-point cost row of the given size with every entry marked absent (NaN). */
private static float[] newCostRow(int size) {
    float[] row = new float[size];
    java.util.Arrays.fill(row, Float.NaN);
    return row;
}


/**
 * Final Held-Karp step: closes every full-set path by travelling from its end point back to
 * node 0, then prints the cheapest resulting tour cost.
 *
 * @param lastKeys not read by this method
 * @param costs    full-set entries, each carrying a path cost and its end point
 */
private void calcLastStop(List<BitSetEndPointID> lastKeys, List<BitSetEndPointID> costs) {
    float best = Float.MAX_VALUE;
    for (BitSetEndPointID entry : costs) {
        float pathCost = entry.getCost();
        int endPoint = entry.lastPointID;
        float tourCost = pathCost + distances[endPoint][0];
        best = (tourCost < best) ? tourCost : best;
    }
    System.out.println("final result: " + best);
}

Answer

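You can speed up your code considerably by using arrays of primitives (they are likely to have a better memory layout than a list of objects) and by operating on bitmasks directly (without BitSets or other objects). Note also that your code looks up previous costs with `List.indexOf`, which scans the whole list on every lookup; indexing an array by bitmask makes each lookup O(1). Here is some code (it generates a random graph, but you can easily change it so that it reads your graph):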

import java.io.*;
import java.util.*;

class Main {

    /** Sentinel cost for unreachable DP states. */
    final static float INF = 1e10f;

    /**
     * Generates a random symmetric 25-node instance, solves it with {@link #tsp}, and prints
     * the optimal tour length.
     */
    public static void main(String[] args) {
        final int n = 25;
        float[][] dist = new float[n][n];
        Random random = new Random();
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                dist[i][j] = dist[j][i] = random.nextFloat();
            }
        }
        double res = tsp(dist);
        System.out.println(res);
    }

    /**
     * Returns the length of the shortest closed tour that starts and ends at node 0 and visits
     * every node exactly once (Held-Karp: O(n^2 * 2^n) time, O(n * 2^n) memory).
     *
     * @param dist {@code dist[a][b]} is the cost of travelling from a to b; must be n x n
     * @return the optimal tour length, 0 for an empty graph, or {@link #INF} if no tour exists
     * @throws IllegalArgumentException if there are more than 30 nodes (the int bitmask would overflow)
     */
    static float tsp(float[][] dist) {
        final int n = dist.length;
        if (n == 0) {
            return 0f;
        }
        if (n > 30) {
            throw new IllegalArgumentException("too many nodes for a bitmask DP: " + n);
        }
        final int full = (1 << n) - 1;
        // dp[v][mask]: cheapest path that starts at 0, visits exactly the nodes in mask, ends at v.
        float[][] dp = new float[n][1 << n];
        for (float[] row : dp) {
            Arrays.fill(row, INF);
        }
        dp[0][1] = 0.0f;
        // Every reachable mask contains start node 0 (bit 0), so only odd masks need work.
        for (int mask = 1; mask <= full; mask += 2) {
            for (int lastNode = 0; lastNode < n; lastNode++) {
                if ((mask & (1 << lastNode)) == 0) {
                    continue;
                }
                float base = dp[lastNode][mask];
                if (base >= INF) {
                    continue; // unreachable state: cannot improve anything
                }
                float[] fromLast = dist[lastNode];
                for (int nextNode = 0; nextNode < n; nextNode++) {
                    if ((mask & (1 << nextNode)) != 0) {
                        continue;
                    }
                    int nextMask = mask | (1 << nextNode);
                    float candidate = base + fromLast[nextNode];
                    if (candidate < dp[nextNode][nextMask]) {
                        dp[nextNode][nextMask] = candidate;
                    }
                }
            }
        }
        float res = INF;
        for (int lastNode = 0; lastNode < n; lastNode++) {
            res = Math.min(res, dist[lastNode][0] + dp[lastNode][full]);
        }
        return res;
    }
}

It takes only a couple of minutes to complete on my computer:

time java Main
...
real    2m5.546s
user    2m2.264s
sys     0m1.572s
Comments