Dang Manh Truong - 2 months ago
Java Question

Java parallelization using parallelStream and forEach where each part is completely separate

I'm writing (online machine learning) code that does the following:

  • You have data and labels (there are N elements in the data)

  • There are K base classifiers: classifiers[1..K]

  • For each element (consecutively, for t = 1 to N, since this is online learning), transform the t-th element in the data into new_data_elem[1..K]. This is the tricky part, where I've decided to use parallelStream().

  • Then each j-th classifier predicts a label, giving labels[1..K] (also in parallel)

  • Vote the results of labels[1..K] into a single decided_label, and compare it with true_label

  • Then for j = 1:K, you apply classifiers[j].update(new_data_elem[j],decided_label == true_label) (the updates need to know whether the prediction was correct) (again, parallel)
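
For concreteness, the voting step in the list above could look like the following minimal sketch. The post does not say which voting rule is used, so plain majority voting over integer labels (ties going to the smallest label) is an assumption, and the class name MajorityVote is hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper, not from the original post: plain majority vote.
final class MajorityVote {
    // Returns the most frequent label; ties go to the smallest label value.
    // Assumes labels is non-empty.
    static int vote(int[] labels) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int label : labels) {
            counts.merge(label, 1, Integer::sum); // count occurrences of each label
        }
        int best = labels[0];
        int bestCount = 0;
        for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
            int count = e.getValue();
            if (count > bestCount || (count == bestCount && e.getKey() < best)) {
                best = e.getKey();
                bestCount = count;
            }
        }
        return best;
    }
}
```

The vote itself runs on a single thread after all K predictions have finished, so it needs no synchronization.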

Here is my (pseudo)code. I've checked and benchmarked it: it produces the same results as the sequential version, with a considerable speedup, but I'm not sure whether it contains potential bugs:

Set<Integer> set_of_index = new HashSet<>();
for (int j = 0; j < K; j++) {
    set_of_index.add(j); // For parallelization
}
for (int t = 0; t < n; t++) {
    true_label = true_labels[t];
    ... // make new_data_elem[1..K]
    // Predict: each task writes only its own labels[j]
    set_of_index.parallelStream().forEach(j ->
        labels[j] = classifiers[j].predict(new_data_elem[j]));
    ... // using labels[j] to predict decided_label
    // Update: each task touches only its own classifiers[j]
    set_of_index.parallelStream().forEach(j ->
        classifiers[j].update(new_data_elem[j], decided_label == true_label));
}

Please check whether this is indeed correct, because I've read http://docs.oracle.com/javase/8/docs/api/java/util/stream/package-summary.html and it says that

A small number of stream operations, such as forEach() and peek(), can operate only via side-effects; these should be used with care.

, so I'm not too sure :(

Answer

Modifying an array element does not interfere with a concurrent modification of a different array element, so, assuming that your calculations do not interfere with each other, the code seems to be correct.
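
The non-interference point can be illustrated with a small sketch. The class and method names are hypothetical; the only thing it is meant to show is that each parallel task writes to its own index, next to a side-effect-free variant that should give the same result.

```java
import java.util.stream.IntStream;

// Hypothetical demo class, not from the original post.
final class DisjointWritesDemo {
    // Side-effecting version: task j writes only out[j], so no two tasks
    // ever touch the same array element.
    static int[] squares(int k) {
        int[] out = new int[k];
        IntStream.range(0, k).parallel().forEach(j -> out[j] = j * j);
        return out;
    }

    // For contrast: the same computation without side effects.
    static int[] squaresNoSideEffects(int k) {
        return IntStream.range(0, k).parallel().map(j -> j * j).toArray();
    }
}
```

By contrast, having every task add to one shared ArrayList would not be safe, because all tasks would then modify the same non-thread-safe object.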

But there are dedicated API methods that make the code even cleaner:

for (int t = 0; t < n; t++) {
    true_label = true_labels[t];
    // make new_data_elem[1..K]
    Arrays.parallelSetAll(labels, j -> classifiers[j].predict(new_data_elem[j]));
    // using labels[j] to predict decided_label
    IntStream.range(0, K).parallel().forEach(j ->
        classifiers[j].update(new_data_elem[j], decided_label == true_label));
}

Arrays.parallelSetAll is designed to write to each array element and makes it easier to reason that there will be no interference between these array writes. By using IntStream.range(0, K), you get rid of the set_of_index entirely.
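
To make this concrete, here is a minimal, self-contained sketch of the whole loop. ToyClassifier, OnlineEnsemble, the per-classifier transform, and the 0/1 majority vote are all hypothetical stand-ins, since the question elides those parts; only Arrays.parallelSetAll and IntStream.range(0, K).parallel().forEach are taken from the code above. The parallel flag exists so the result can be compared with a sequential run.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Hypothetical stand-in for a base classifier; not from the original post.
final class ToyClassifier {
    private double threshold; // state owned by this classifier only

    ToyClassifier(double threshold) { this.threshold = threshold; }

    int predict(double x) { return x >= threshold ? 1 : 0; }

    // Nudges the threshold when the ensemble was wrong; touches no shared state.
    void update(double x, boolean ensembleWasCorrect) {
        if (!ensembleWasCorrect) {
            threshold += (x >= threshold) ? 0.1 : -0.1;
        }
    }

    double threshold() { return threshold; }
}

final class OnlineEnsemble {
    // Runs the online loop and returns the final threshold of each classifier.
    static double[] run(double[] data, int[] trueLabels, int k, boolean parallel) {
        ToyClassifier[] classifiers = new ToyClassifier[k];
        for (int j = 0; j < k; j++) {
            classifiers[j] = new ToyClassifier(j / (double) k);
        }
        double[] newDataElem = new double[k];
        int[] labels = new int[k];

        for (int t = 0; t < data.length; t++) { // stays sequential: online learning
            final double x = data[t];
            if (parallel) {
                // Assumed transform: a per-classifier offset of the input.
                Arrays.parallelSetAll(newDataElem, j -> x + 0.01 * j);
                Arrays.parallelSetAll(labels, j -> classifiers[j].predict(newDataElem[j]));
            } else {
                Arrays.setAll(newDataElem, j -> x + 0.01 * j);
                Arrays.setAll(labels, j -> classifiers[j].predict(newDataElem[j]));
            }
            // Assumed voting rule: majority of 0/1 labels, ties count as 1.
            int ones = Arrays.stream(labels).sum();
            int decided = (2 * ones >= k) ? 1 : 0;
            // Computed once per step; effectively final, so the lambda may capture it.
            final boolean correct = decided == trueLabels[t];

            IntStream indices = IntStream.range(0, k);
            if (parallel) {
                indices = indices.parallel();
            }
            indices.forEach(j -> classifiers[j].update(newDataElem[j], correct));
        }
        return Arrays.stream(classifiers).mapToDouble(ToyClassifier::threshold).toArray();
    }
}
```

Calling run(data, trueLabels, K, false) and run(data, trueLabels, K, true) on the same input should return identical arrays, because each classifier reads and writes only its own state. Note that a lambda can only capture local variables that are effectively final, which is why the comparison decided == trueLabels[t] is evaluated before the forEach in this sketch.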