jimpudar jimpudar - 3 months ago 10
Java Question

How can I build this tree with O(n) space complexity?

The Problem



Given a set of integers, find a subset of those integers which sum to 100,000,000.

Solution



I am attempting to build a tree containing all the combinations of the given set along with the sum. For example, if the given set looked like 0,1,2, I would build the following tree, checking the sum at each node:

{}
{} {0}
{} {1} {0} {0,1}
{} {2} {1} {1,2} {0} {0,2} {0,1} {0,1,2}


Since I keep both the array of integers at each node and the sum, I should only need the bottom (current) level of the tree in memory.

Issues



My current implementation will maintain the entire tree in memory and therefore uses way too much heap space.

How can I change my current implementation so that the GC will take care of my upper tree levels?

(At the moment I am just throwing a RuntimeException when I have found the target sum, but this is obviously just for playing around.)

public class RecursiveSolver {
static final int target = 100000000;
static final int[] set = new int[]{98374328, 234234123, 2341234, 123412344, etc...};

Tree initTree() {
return nextLevel(new Tree(null), 0);
}

Tree nextLevel(Tree currentLocation, int current) {
if (current == set.length) { return null; }
else if (currentLocation.sum == target) throw new RuntimeException(currentLocation.getText());
else {
currentLocation.left = nextLevel(currentLocation.copy(), current + 1);
Tree right = currentLocation.copy();
right.value = add(currentLocation.value, set[current]);
right.sum = currentLocation.sum + set[current];
currentLocation.right = nextLevel(right, current + 1);
return currentLocation;
}
}

int[] add(int[] array, int digit) {
if (array == null) {
return new int[]{digit};
}
int[] newValue = new int[array.length + 1];
for (int i = 0; i < array.length; i++) {
newValue[i] = array[i];
}
newValue[array.length] = digit;
return newValue;
}

public static void main(String[] args) {
RecursiveSolver rs = new RecursiveSolver();
Tree subsetTree = rs.initTree();
}
}

class Tree {
Tree left;
Tree right;
int[] value;
int sum;

Tree(int[] value) {
left = null;
right = null;
sum = 0;
this.value = value;
if (value != null) {
for (int i = 0; i < value.length; i++) sum += value[i];
}
}

Tree copy() {
return new Tree(this.value);
}
}

Answer

After thinking more about erip's comments, I realized he is correct - I shouldn't be using a tree to implement this algorithm.

Brute force is usually O(n*2^n), because each of the 2^n subsets needs up to n additions. Because I only do one addition per node, the solution I came up with is O(2^n), where n is the size of the given set. The recursion is also only n levels deep, and nothing outside the current path stays referenced, so the GC can reclaim everything else and the space complexity is only O(n). Since the number of elements in the original set in my particular problem is small (around 25), O(2^n) complexity is not too much of a problem.

The dynamic programming solution to this problem is O(t*n), where t is the target sum and n is the number of elements. Because t is very large in my problem, the dynamic programming solution ends up with a very long runtime and high memory usage.
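For comparison, here is a minimal sketch of the kind of dynamic programming approach I mean. It is not code from any particular answer; the class name SubsetSumDp and the method canReach are made up for illustration, and it assumes all values are positive. It keeps one bit per reachable sum from 0 to t, so both the table size and the inner loop grow with t:

```java
import java.util.BitSet;

class SubsetSumDp {
    // Hypothetical helper: reachable.get(s) is true when some subset of the
    // values processed so far sums to exactly s.
    static boolean canReach(int[] values, int target) {
        BitSet reachable = new BitSet(target + 1);
        reachable.set(0); // the empty subset sums to 0
        for (int v : values) {
            // Assumption of this sketch: only positive values are handled.
            if (v <= 0 || v > target) continue;
            // Walk downwards so each value is used at most once.
            for (int s = target - v; s >= 0; s--) {
                if (reachable.get(s)) reachable.set(s + v);
            }
        }
        return reachable.get(target);
    }

    public static void main(String[] args) {
        System.out.println(canReach(new int[]{3, 5, 9}, 14)); // true  (5 + 9)
        System.out.println(canReach(new int[]{3, 5, 9}, 7));  // false
    }
}
```

With t = 100,000,000 the inner loop runs up to t times for each of the n elements, which is where the long runtime comes from.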

The code below solves my particular problem in around 311 ms on my machine, which is a tremendous improvement over the dynamic programming solutions I have seen for this particular class of problem.

public class TailRecursiveSolver {
    public static void main(String[] args) {
        final long starttime = System.currentTimeMillis();
        try {
            step(new Subset(null, 0), 0);
        }
        catch (RuntimeException ex) {
            System.out.println(ex.getMessage());
            final long endtime = System.currentTimeMillis();
            System.out.println(endtime - starttime);
        }
    }

    static final int target = 100000000;
    static final int[] set = new int[]{ . . . };

    static void step(Subset current, int counter) {
        if (current.sum == target) throw new RuntimeException(current.getText());
        else if (counter == set.length) {}
        else {
            step(new Subset(add(current.subset, set[counter]), current.sum + set[counter]), counter + 1);
            step(current, counter + 1);
        }
    }

    static int[] add(int[] array, int digit) {
        if (array == null) {
            return new int[]{digit};
        }
        int[] newValue = new int[array.length + 1];
        for (int i = 0; i < array.length; i++) {
            newValue[i] = array[i];
        }
        newValue[array.length] = digit;
        return newValue;
    }
}

class Subset {
    int[] subset;
    int sum;

    Subset(int[] subset, int sum) {
        this.subset = subset;
        this.sum = sum;
    }

    public String getText() {
        String ret = "";
        for (int i = 0; i < (subset == null ? 0 : subset.length); i++) {
            ret += " + " + subset[i];
        }
        if (ret.startsWith(" ")) {
            ret = ret.substring(3);
            ret = ret + " = " + sum;
        } else ret = "null";
        return ret;
    }
}

EDIT -

The above code still runs in O(n*2^n) time, since the add method copies the array in O(n) time on every call. The following code runs in true O(2^n) time and is much more performant, completing in around 20 ms on my machine.

It is limited to sets of fewer than 64 elements, because the current subset is stored as the bits of a long.

public class SubsetSumSolver {
    static boolean found = false;
    static final int target = 100000000;
    static final int[] set = new int[]{ . . . };

    public static void main(String[] args) {
        step(0,0,0);
    }

    static void step(long subset, int sum, int counter) {
        if (sum == target) {
            found = true;
            System.out.println(getText(subset, sum));
        }
        else if (!found && counter != set.length) {
            // 1L forces a 64-bit shift; a plain int 1 would wrap around once counter reaches 32
            step(subset + (1L << counter), sum + set[counter], counter + 1);
            step(subset, sum, counter + 1);
        }
    }

    static String getText(long subset, int sum) {
        String ret = "";
        for (int i = 0; i < 64; i++) if((1 & (subset >> i)) == 1) ret += " + " + set[i];
        if (ret.startsWith(" ")) ret = ret.substring(3) + " = " + sum;
        else ret = "null";
        return ret;
    }
}
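To make the bit trick a little more concrete, here is a small standalone sketch of encoding and decoding a subset as the bits of a long. The class BitmaskDemo and its methods with and decode are hypothetical names for this illustration only, and it assumes the array has at most 64 elements:

```java
import java.util.ArrayList;
import java.util.List;

class BitmaskDemo {
    // Adds element i to the subset; 1L keeps the shift 64 bits wide.
    static long with(long mask, int i) {
        return mask | (1L << i);
    }

    // Returns the elements of values whose index bit is set in mask.
    static List<Integer> decode(long mask, int[] values) {
        List<Integer> chosen = new ArrayList<>();
        for (int i = 0; i < values.length; i++) {
            if (((mask >>> i) & 1L) == 1L) chosen.add(values[i]);
        }
        return chosen;
    }

    public static void main(String[] args) {
        int[] values = {10, 20, 30};
        long mask = with(with(0L, 0), 2); // choose indices 0 and 2
        System.out.println(decode(mask, values)); // [10, 30]

        // Pitfall: an int shift only uses the low 5 bits of the distance.
        System.out.println(1 << 40);  // 256, because 40 % 32 == 8
        System.out.println(1L << 40); // 1099511627776
    }
}
```

Passing the subset around as a single long is what removes the O(n) array copy from every recursive call.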

EDIT 2 -

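Here is another version that uses a meet-in-the-middle attack, along with a little bit shifting. Each half of the set has only 2^(n/2) subsets to enumerate instead of 2^n. Note that the nested loop in main still compares every pair of half-sums, which is O(2^n) comparisons in the worst case; sorting one of the lists or looking the sums up in a hash map is needed to bring the whole search down to roughly O(2^(n/2)).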

If you want to use this for sets with between 32 and 64 elements, you should change the int which represents the current subset in the step function to a long, although performance will obviously decrease drastically as the set size increases. If you want to use this for a set with an odd number of elements, add a 0 to the set to make its length even, because split allocates exactly superSet.length / 2 slots for each half.

import java.util.ArrayList;
import java.util.List;

public class SubsetSumMiddleAttack {
    static final int target = 100000000;
    static final int[] set = new int[]{ ... };

    static List<Subset> evens = new ArrayList<>();
    static List<Subset> odds = new ArrayList<>();

    static int[][] split(int[] superSet) {
        int[][] ret = new int[2][superSet.length / 2]; 

        for (int i = 0; i < superSet.length; i++) ret[i % 2][i / 2] = superSet[i];

        return ret;
    }

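    static void step(int[] superSet, List<Subset> accumulator, int subset, int sum, int counter) {
        if (counter == superSet.length) {
            // Record each subset exactly once, at the leaf where every element has been decided.
            // Adding at every call would store the same subset several times.
            accumulator.add(new Subset(subset, sum));
        } else {
            step(superSet, accumulator, subset + (1 << counter), sum + superSet[counter], counter + 1);
            step(superSet, accumulator, subset, sum, counter + 1);
        }
    }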

    static void printSubset(Subset e, Subset o) {
        String ret = "";
        for (int i = 0; i < set.length; i++) {
            if (i % 2 == 0) {
                if ((1 & (e.subset >> (i / 2))) == 1) ret += " + " + set[i];
            }
            else {
                if ((1 & (o.subset >> (i / 2))) == 1) ret += " + " + set[i];
            }
        }
        if (ret.startsWith(" ")) ret = ret.substring(3) + " = " + (e.sum + o.sum);
        System.out.println(ret);
    }

    public static void main(String[] args) {
        int[][] superSets = split(set);

        step(superSets[0], evens, 0,0,0);
        step(superSets[1], odds, 0,0,0);

        for (Subset e : evens) {
            for (Subset o : odds) {
                if (e.sum + o.sum == target) printSubset(e, o);
            }
        }
    }
}

class Subset {
    int subset;
    int sum;

    Subset(int subset, int sum) {
        this.subset = subset;
        this.sum = sum;
    }
}
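The nested loop in main above checks every pair of half-sums. One way to avoid that is to store the sums of one half in a hash map and look up target minus each sum of the other half. The following is only a sketch of that idea, not a drop-in replacement: the class MeetInTheMiddleLookup and the methods allSums and solve are names invented here, it returns the two bit masks instead of printing, and it assumes each half has few enough elements (well under 31) that an array of 2^(n/2) sums fits in memory.

```java
import java.util.HashMap;
import java.util.Map;

class MeetInTheMiddleLookup {
    // sums[mask] is the sum of the elements of half selected by the bits of mask.
    static long[] allSums(int[] half) {
        long[] sums = new long[1 << half.length];
        for (int mask = 1; mask < sums.length; mask++) {
            int low = Integer.numberOfTrailingZeros(mask); // lowest chosen index
            // mask & (mask - 1) clears that lowest bit, giving an already computed entry.
            sums[mask] = sums[mask & (mask - 1)] + half[low];
        }
        return sums;
    }

    // Returns {leftMask, rightMask} for one pair of subsets summing to target,
    // or null if there is none.
    static int[] solve(int[] left, int[] right, long target) {
        long[] leftSums = allSums(left);
        Map<Long, Integer> bySum = new HashMap<>();
        for (int mask = 0; mask < leftSums.length; mask++) {
            bySum.putIfAbsent(leftSums[mask], mask);
        }

        long[] rightSums = allSums(right);
        for (int mask = 0; mask < rightSums.length; mask++) {
            Integer match = bySum.get(target - rightSums[mask]);
            if (match != null) return new int[]{match, mask};
        }
        return null;
    }

    public static void main(String[] args) {
        int[] result = solve(new int[]{1, 2}, new int[]{4, 8}, 13);
        // leftMask 1 selects {1}, rightMask 3 selects {4, 8}: 1 + 4 + 8 = 13
        System.out.println(result[0] + " " + result[1]); // 1 3
    }
}
```

Each lookup is expected constant time, so the total work is proportional to the number of half-subsets rather than to the number of pairs. The sums are kept as long here so that adding many large values cannot overflow.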