stats0007 - 2 months ago
Java Question

java8 parallel string statistics / stream / map/reduce

I have a very long string in Java and I am trying to get statistics on that string.

For example, String s = "afafaf"

I want the count of every substring of length 2 that occurs in it, overlapping occurrences included.
For the small example above, this would be:

"af" - 3
"fa" - 2

Another example:
String s = "hsdjs"

Result:
"hs" - 1
"sd" - 1
"dj" - 1
"js" - 1
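For reference, here is a minimal sequential sketch of the counting described above, using only the JDK's HashMap instead of Guava. Pairs overlap, so "afafaf" yields five pairs in total. The class name BigramCounts and the method countPairs are hypothetical, chosen only for illustration.

```java
import java.util.HashMap;
import java.util.Map;

class BigramCounts {

    // Counts every substring of length 2, sliding one character at a time,
    // so overlapping pairs are all counted.
    static Map<String, Integer> countPairs(String s) {
        Map<String, Integer> counts = new HashMap<>();
        for (int i = 0; i + 1 < s.length(); i++) {
            // merge() stores 1 for a new key, or adds 1 to the existing count
            counts.merge(s.substring(i, i + 2), 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(countPairs("afafaf"));
        System.out.println(countPairs("hsdjs"));
    }
}
```

Map.merge does the lookup and the update in one call, which avoids a separate get and put for each substring.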

What I did, and what works, is loop over the string with for (int i = 0; i < s.length(); i++) and update the counts in a Map.

The problem is that this is very slow.
I thought the new Java 8 features for parallel processing might help, but unfortunately I wasn't able to get anything working. Maybe someone can help me out.

Current code:

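import com.google.common.collect.HashMultiset;

String inputString = s;
// Guava multiset: counts how many times each element has been added
HashMultiset<String> multi = HashMultiset.create();

// Slide a two-character window over the string (overlapping pairs)
for (int i = 0; i < inputString.length() - 1; i++) {
    String aktuellerString = inputString.substring(i, i + 2); // "aktuellerString" = current string
    multi.add(aktuellerString);
}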


Here is the current profiling result:
http://fs5.directupload.net/images/160909/naadsfxi.png

The add() method of Guava's HashMultiset takes most of the time overall, but this was the fastest collection I could find. (I tried several other optimized collections, including a plain HashMap, Tie, GS Collections, gnu.trove.map.hash.THashMap, org.apache.commons.collections.FastHashMap, ...).

That's why I thought parallel processing might be the only way to speed it up.

Answer

UPDATE: As Marko points out, the cost of creating substrings is significant, and even with multiple CPUs you can do better with a structure that avoids creating them. Here each key is only two characters, so it can be encoded as a single int value, assuming ASCII characters.
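To make the encoding concrete: an ASCII character fits in 7 bits (0 to 127), so a pair (c1, c2) can be stored at index c1 * 128 + c2 of an int array with 128 * 128 slots, and a shift and a mask recover the two characters. This is a small sketch of that arithmetic only; the class PairCode and its methods encode and decode are hypothetical names, and any character above 127 breaks the assumption.

```java
class PairCode {

    // Assumes both characters are 7-bit ASCII (0..127); anything larger
    // would produce an index outside a 128 * 128 table.
    static int encode(char c1, char c2) {
        return c1 * 128 + c2; // equivalent to (c1 << 7) | c2 for ASCII input
    }

    static String decode(int code) {
        char first = (char) (code >> 7);    // high 7 bits
        char second = (char) (code & 0x7f); // low 7 bits
        return "" + first + second;
    }

    public static void main(String[] args) {
        int code = encode('a', 'f');
        System.out.println(code + " -> " + decode(code)); // prints "12518 -> af"
    }
}
```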

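import java.io.IOException;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public static void main(String[] args) throws IOException {
    // 10^9 random lowercase letters (the char array and the String need several GB of heap)
    char[] chars = new char[1000000000];
    Random rand = new Random();
    for (int i = 0; i < chars.length; i++)
        chars[i] = (char) (rand.nextInt(26) + 'a');
    String s = new String(chars);

    long start = System.currentTimeMillis();
    // Creates one two-character String per index, then groups and counts them
    Map<String, Integer> freq = IntStream.range(0, s.length() - 1).parallel()
            .mapToObj(i -> s.substring(i, i + 2))
            .collect(Collectors.groupingBy(w -> w, Collectors.summingInt(e -> 1)));
    long time = System.currentTimeMillis() - start;
    System.out.println("Took " + time + " ms " + freq);
}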

prints

Took 8401 ms {aa=1479201, ab=1478451, ac=1479055, ...

However, if we call collect directly, we can use a structure that doesn't create an object for every pair.

import java.io.IOException;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import java.util.stream.IntStream;

public static void main(String[] args) throws IOException {
    char[] chars = new char[1000000000];
    Random rand = new Random();
    for (int i = 0; i < chars.length; i++)
        chars[i] = (char) (rand.nextInt(26) + 'a');
    String s = new String(chars);

    long start = System.currentTimeMillis();
    // Assumes ASCII: the pair (c1, c2) is counted at index c1 * 128 + c2
    int[] freqArr = IntStream.range(0, s.length() - 1).parallel()
            .collect(() -> new int[128 * 128],                               // one counter table per parallel chunk
                    (arr, i) -> arr[s.charAt(i) * 128 + s.charAt(i + 1)]++, // count the pair starting at i
                    (a, b) -> sum(a, b));                                    // merge two chunks' tables
    // Decode the non-zero slots back into two-character keys
    Map<String, Integer> freq = new TreeMap<>();
    for (int i = 0; i < freqArr.length; i++) {
        int c = freqArr[i];
        if (c == 0) continue;
        String key = "" + (char) (i >> 7) + (char) (i & 0x7f);
        freq.put(key, c);
    }
    long time = System.currentTimeMillis() - start;
    System.out.println("Took " + time + " ms " + freq);
}

// Adds b into a, element by element, and returns a
static int[] sum(int[] a, int[] b) {
    for (int i = 0; i < a.length; i++)
        a[i] += b[i];
    return a;
}
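Scaled down to the question's example so the result can be checked by hand, the same three-argument IntStream.collect technique looks like this. The supplier creates one counter array per parallel chunk, the accumulator increments one slot, and the combiner adds two arrays together. Like the version above, it assumes every character is ASCII; the class ArrayPairCounts and the method countPairs are hypothetical names.

```java
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.IntStream;

class ArrayPairCounts {

    // Assumes all characters are 7-bit ASCII (0..127).
    static Map<String, Integer> countPairs(String s) {
        int[] counts = IntStream.range(0, s.length() - 1).parallel()
                .collect(() -> new int[128 * 128],                        // one table per chunk
                        (arr, i) -> arr[s.charAt(i) * 128 + s.charAt(i + 1)]++,
                        (a, b) -> {                                        // fold b into a
                            for (int k = 0; k < a.length; k++) a[k] += b[k];
                        });
        // Decode only the non-zero slots back into two-character keys
        Map<String, Integer> result = new TreeMap<>();
        for (int code = 0; code < counts.length; code++) {
            if (counts[code] != 0) {
                result.put("" + (char) (code >> 7) + (char) (code & 0x7f), counts[code]);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(countPairs("afafaf")); // prints {af=3, fa=2}
    }
}
```

Because each chunk gets its own array, no locking is needed; the only extra work is adding the 16,384-slot arrays together when chunks are combined.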

prints the following, which is roughly 20x faster:

Took 404 ms {aa=1479575, ab=1480511, ac=1476255, 

This makes a big difference because the strings are so small: creating a two-character String for every index costs far more than incrementing a counter.


You can replace

for (int i = 0; i < s.length(); i++) { something(i); }

with

IntStream.range(0, s.length()).parallel().forEach(i -> { something(i); });

but a better solution is to map each index to its two-character substring and collect the counts:

String s = "afafaffafafafffafaaaf";

Map<String, Long> freq = IntStream.range(0, s.length()-1).parallel() // 1
        .mapToObj(i -> s.substring(i, i + 2)) // 2
        .collect(Collectors.groupingBy(w -> w, Collectors.counting())); //3

System.out.println(freq);
  1. Go through, in parallel, every index at which a two-character string starts.
  2. Obtain the two-character String at each of those indexes.
  3. Group identical Strings and count the occurrences of each. With a parallel stream, each chunk builds its own map and the partial maps are merged.

prints

{ff=3, aa=2, af=8, fa=7}
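About the forEach replacement suggested earlier: the lambda body then runs on several threads at once, so if something(i) updates a shared collection, that collection must be thread-safe. A plain HashMap, or Guava's HashMultiset from the question, can lose counts or be corrupted when updated this way. A minimal sketch, assuming a java.util.concurrent.ConcurrentHashMap, whose merge method updates each key atomically; the class ParallelForEachCounts and method countPairs are hypothetical names.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.IntStream;

class ParallelForEachCounts {

    static ConcurrentMap<String, Integer> countPairs(String s) {
        ConcurrentMap<String, Integer> counts = new ConcurrentHashMap<>();
        IntStream.range(0, s.length() - 1).parallel()
                // ConcurrentHashMap.merge is atomic per key, so concurrent
                // increments of the same pair are not lost
                .forEach(i -> counts.merge(s.substring(i, i + 2), 1, Integer::sum));
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(countPairs("hsdjs"));
    }
}
```

This only fixes correctness: it still creates one substring per index and every thread updates the same map, so it is unlikely to beat the collector-based version above.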

Regarding Holger's point that groupingByConcurrent can be slower, I tested four cases:

    long start = System.currentTimeMillis();
    Map<Integer, Long> freq = IntStream.range(0, 1000000000)/*.parallel()*/
            .mapToObj(i -> i % 10)
            .collect(Collectors.groupingBy/*Concurrent*/(w -> w, Collectors.counting()));
    long time = System.currentTimeMillis() - start;
    System.out.println("Took " + time+" ms " + freq);

without parallel(), with groupingBy : Took 14156 ms {0=100000000, 1=100000000, 2=100000000, 3=100000000, 4=100000000, 5=100000000, 6=100000000, 7=100000000, 8=100000000, 9=100000000}
with parallel(), with groupingBy : Took 5581 ms {0=100000000, 1=100000000, 2=100000000, 3=100000000, 4=100000000, 5=100000000, 6=100000000, 7=100000000, 8=100000000, 9=100000000}
without parallel(), with groupingByConcurrent : Took 38218 ms {0=100000000, 1=100000000, 2=100000000, 3=100000000, 4=100000000, 5=100000000, 6=100000000, 7=100000000, 8=100000000, 9=100000000}
with parallel(), with groupingByConcurrent : Took 27619 ms {0=100000000, 1=100000000, 2=100000000, 3=100000000, 4=100000000, 5=100000000, 6=100000000, 7=100000000, 8=100000000, 9=100000000}

Plain groupingBy beat groupingByConcurrent in both cases, with and without parallel().
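Both collectors return the same counts; they differ in how the result map is built. With a parallel stream, groupingBy gives each chunk its own map and merges the maps afterwards, while groupingByConcurrent has all threads insert into one shared ConcurrentMap. Contention on that shared map is a plausible contributor to the slower parallel timing, but that is an assumption not verified by the measurements above. A small-scale sketch for checking that the two agree; the class GroupingComparison and the methods plain and concurrent are hypothetical names.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class GroupingComparison {

    // groupingBy: per-chunk maps, merged when chunks are combined
    static Map<Integer, Long> plain(int n) {
        return IntStream.range(0, n).parallel()
                .mapToObj(i -> i % 10)
                .collect(Collectors.groupingBy(w -> w, Collectors.counting()));
    }

    // groupingByConcurrent: all threads insert into one shared ConcurrentMap
    static ConcurrentMap<Integer, Long> concurrent(int n) {
        return IntStream.range(0, n).parallel()
                .mapToObj(i -> i % 10)
                .collect(Collectors.groupingByConcurrent(w -> w, Collectors.counting()));
    }

    public static void main(String[] args) {
        int n = 1_000_000; // much smaller than the benchmark above
        System.out.println(plain(n).equals(concurrent(n))); // prints true
    }
}
```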

Following another of Holger's comments, replacing counting() with summingInt(e -> 1) proved faster again:

long start = System.currentTimeMillis();
Map<Integer, Integer> freq = IntStream.range(0, 1000000000).parallel()
        .mapToObj(i -> i % 10)
        .collect(Collectors.groupingBy(w -> w, Collectors.summingInt(e -> 1)));
long time = System.currentTimeMillis() - start;
System.out.println("Took " + time+" ms " + freq);

prints

Took 4131 ms {0=100000000, 1=100000000, 2=100000000, 3=100000000, 4=100000000, 5=100000000, 6=100000000, 7=100000000, 8=100000000, 9=100000000}
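One side effect of the switch is the value type: counting() produces Long values, while summingInt(e -> 1) produces Integer values because it adds into an int. That is fine here, where each key's count stays far below Integer.MAX_VALUE, but a key seen more than about 2.1 billion times would overflow silently; Collectors.summingLong(e -> 1L) avoids that (its speed was not measured above). A small sketch showing the two result types side by side; the class DownstreamComparison and its methods are hypothetical names.

```java
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class DownstreamComparison {

    // counting(): each value is a Long
    static Map<Integer, Long> withCounting(int n) {
        return IntStream.range(0, n).parallel()
                .mapToObj(i -> i % 10)
                .collect(Collectors.groupingBy(w -> w, Collectors.counting()));
    }

    // summingInt(e -> 1): adds 1 per element into an int, so each value is an Integer
    static Map<Integer, Integer> withSummingInt(int n) {
        return IntStream.range(0, n).parallel()
                .mapToObj(i -> i % 10)
                .collect(Collectors.groupingBy(w -> w, Collectors.summingInt(e -> 1)));
    }

    public static void main(String[] args) {
        int n = 1_000; // tiny input, only to compare the result types
        Map<Integer, Long> counted = withCounting(n);
        Map<Integer, Integer> summed = withSummingInt(n);
        System.out.println(counted + " " + summed);
    }
}
```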