solistice - 5 months ago

Scala Question

I've been trying to find a way to count the number of times sets of Strings occur in a transaction database (implementing the Apriori algorithm in a distributed fashion). The code I have currently is as follows:

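```
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import scala.collection.mutable.ArrayBuffer

// Enclosing method: the original header was lost, so this name and signature are a reconstruction
def countCandidates(sc: SparkContext, transactions: RDD[Set[String]], cand: Array[Set[String]]): RDD[(Set[String], Int)] = {
  // Ship the candidate array to every executor once
  val cand_br = sc.broadcast(cand)
  transactions.flatMap(trans => freq(trans, cand_br.value))
    .reduceByKey(_ + _) // sum the 1s emitted for each candidate
}

// Emit (candidate, 1) for every candidate that is a subset of this transaction
def freq(trans: Set[String], cand: Array[Set[String]]): Array[(Set[String], Int)] = {
  val res = ArrayBuffer[(Set[String], Int)]()
  for (c <- cand) {
    if (c.subsetOf(trans)) {
      res += ((c, 1))
    }
  }
  res.toArray
}
```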

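`transactions` starts out as an `RDD[Set[String]]`, which I want to turn into an `RDD[(K, V)]`, with K being each element of `cand` and V the number of transactions in which that element of `cand` occurs.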

When watching performance in the Spark UI, the flatMap stage takes about 3 minutes to finish, whereas the rest takes < 1 ms.

- `transactions.count() ~= 88000`
- `cand.length ~= 24000`

Is there a more optimal solution to solve this subproblem?

PS: I'm fairly new to Scala and the Spark framework, so there might be some strange constructions in this code.

Answer

Probably the right question to ask in this case is: "What is the time complexity of this algorithm?" I think it is largely unrelated to Spark's `flatMap` operation.

Given two collections of sets of sizes `m` and `n`, this algorithm counts how many elements of one collection are subsets of elements of the other, so its complexity looks like `m x n`. Looking one level deeper, `subsetOf` is linear in the number of elements of the subset, since `x subsetOf y` == `x forall y`. So the actual complexity is `m x n x s`, where `s` is the cardinality of the subsets being checked.

In other words, this `flatMap` operation has a lot of work to do.
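To make the `m x n x s` estimate tangible, here is a minimal script-style sketch (top-level definitions, as in a Scala worksheet or REPL) that runs the same all-pairs subset test as `freq`, written as `forall` and instrumented to count membership lookups on a toy dataset. `subsetOfCounted` and `countChecks` are hypothetical helpers introduced only for this illustration. Because `forall` stops at the first missing element, the observed count is at most `m x n x s`; the bound is reached only when every test runs to the end.

```scala
// Hypothetical instrumented version of `x subsetOf y`, written as `x forall y`:
// every element of x that gets looked up in y increments counter(0)
def subsetOfCounted(x: Set[String], y: Set[String], counter: Array[Long]): Boolean =
  x.forall { e =>
    counter(0) += 1
    y.contains(e)
  }

// Test every candidate against every transaction (the m x n pairs) and return the lookup count
def countChecks(transactions: Seq[Set[String]], cand: Seq[Set[String]]): Long = {
  val counter = Array(0L)
  for (t <- transactions; c <- cand) subsetOfCounted(c, t, counter)
  counter(0)
}

val transactions = Seq(Set("a", "b", "c"), Set("a", "c"), Set("b", "d"))
val cand = Seq(Set("a"), Set("a", "c"), Set("b", "d"))

val m = transactions.size    // number of transactions
val n = cand.size            // number of candidates
val s = cand.map(_.size).max // largest candidate cardinality
val checks = countChecks(transactions, cand)
println(s"membership checks = $checks, upper bound m * n * s = ${m * n * s}")
```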

Now, going back to Spark, we can also observe that this algorithm is *embarrassingly parallel*, so we can use Spark's capabilities to our advantage.
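The same decomposition can be sketched in plain Scala, without Spark, as a rough analogue rather than a description of how Spark schedules work: split the transactions into chunks (standing in for partitions), count each chunk concurrently on a `scala.concurrent.Future`, then sum the partial counts per candidate, which is the job `reduceByKey` does. `countChunk`, `mergeCounts` and `parallelCounts` are hypothetical names, and `numChunks` plays the role of the partition count.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

// Hypothetical helper: per candidate, count the transactions in one chunk that contain it
def countChunk(chunk: Seq[Set[String]], cand: Seq[Set[String]]): Map[Set[String], Int] =
  cand.map(c => c -> chunk.count(t => c.subsetOf(t))).filter(_._2 > 0).toMap

// Hypothetical helper: sum two partial count maps key by key (what reduceByKey does in Spark)
def mergeCounts(a: Map[Set[String], Int], b: Map[Set[String], Int]): Map[Set[String], Int] =
  b.foldLeft(a) { case (acc, (k, v)) => acc.updated(k, acc.getOrElse(k, 0) + v) }

// Split the transactions into roughly `numChunks` chunks and count each chunk on its own Future
def parallelCounts(transactions: Seq[Set[String]], cand: Seq[Set[String]], numChunks: Int): Map[Set[String], Int] = {
  val chunkSize = math.max(1, (transactions.size + numChunks - 1) / numChunks)
  val futures = transactions.grouped(chunkSize).toList.map(chunk => Future(countChunk(chunk, cand)))
  val partials = Await.result(Future.sequence(futures), Duration.Inf)
  partials.foldLeft(Map.empty[Set[String], Int])(mergeCounts)
}

val transactions = Seq(Set("a", "b", "c"), Set("a", "c"), Set("b", "d"), Set("a", "b"))
val cand = Seq(Set("a"), Set("a", "c"), Set("b", "d"), Set("e"))
val counts = parallelCounts(transactions, cand, numChunks = 2)
println(counts)
```

Since no chunk depends on another, changing `numChunks` should only change how the work is spread, not the resulting counts.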

To compare some approaches, I loaded the 'retail' dataset [1] and ran the algorithm on `val cand = transactions.filter(_.size<4).collect`. The data size is close to the one in the question:

- `transactions.count` = 88162
- `cand.size` = 15451
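Plugging these sizes into the `m x n x s` estimate gives a rough sense of the work; `s` is at most 3 here because `cand` only keeps sets with fewer than 4 items, and the last line uses the question's approximate sizes for comparison:

```latex
\begin{aligned}
m \times n &= 88162 \times 15451 = 1\,362\,191\,062 \approx 1.36 \times 10^{9} \text{ subset tests} \\
m \times n \times s &\le 3 \times 1\,362\,191\,062 = 4\,086\,573\,186 \approx 4.09 \times 10^{9} \text{ membership checks} \\
88000 \times 24000 &= 2\,112\,000\,000 \approx 2.1 \times 10^{9} \text{ subset tests (question)}
\end{aligned}
```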

Some comparative runs in local mode:

- Vanilla: *1.2 minutes*
- Increasing `transactions` partitions up to the number of cores (8): *33 secs*

I also tried an alternative implementation, using `cartesian` instead of `flatMap`:

```
// candRDD: the candidate sets as an RDD, e.g. sc.parallelize(cand.toSeq)
transactions.cartesian(candRDD).map{case (tx,cd) => (cd,if (cd.subsetOf(tx)) 1 else 0)}.reduceByKey(_ + _).collect
```

But that resulted in much longer runs, as seen in the top two rows of the Spark UI (cartesian, and cartesian with a higher number of partitions): *2.5 min*.

Given that I only have 8 logical cores available, going above 8 partitions does not help.

Is there any added 'Spark flatMap time complexity'? Probably some, as it involves serializing closures and unpacking collections, but it is negligible compared with the function being executed.

Let's see if we can do a better job: I implemented the same algorithm in plain Scala:

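```
// transLocal: the transactions collected locally; reduceByKey: the naive local version from [2]
val resLocal = reduceByKey(transLocal.flatMap(trans => freq(trans, cand)))
```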

Where the `reduceByKey` operation is a naive implementation taken from [2].

Execution time: *3.67 seconds*.

Spark gives you parallelism out of the box. This implementation is entirely sequential and therefore takes longer to complete.
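The helper from [2] is not reproduced in this answer. As a stand-in, here is a minimal sketch, assuming Scala 2.13+ for `groupMapReduce`, of a local `reduceByKey` that fits the one-argument call above, plus a `freq` variant over a `Seq`; both are hypothetical illustrations, not the code behind the timing.

```scala
// Hypothetical local stand-in for Spark's reduceByKey, summing Int values per key
// (not the implementation from [2]); groupMapReduce requires Scala 2.13+
def reduceByKey[K](pairs: Seq[(K, Int)]): Map[K, Int] =
  pairs.groupMapReduce(_._1)(_._2)(_ + _)

// Same contract as the question's freq: one (candidate, 1) pair per contained candidate
def freq(trans: Set[String], cand: Seq[Set[String]]): Seq[(Set[String], Int)] =
  cand.collect { case c if c.subsetOf(trans) => (c, 1) }

val transLocal = Seq(Set("a", "b"), Set("a"), Set("b", "c"))
val cand = Seq(Set("a"), Set("b"), Set("a", "b"), Set("c", "d"))
val resLocal = reduceByKey(transLocal.flatMap(trans => freq(trans, cand)))
println(resLocal)
```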

Last sanity check: a trivial `flatMap` operation:

```
transactions.flatMap(trans => Seq((trans, 1))).reduceByKey( _ + _).collect
```

Execution time: *0.88 secs*

Spark is buying you parallelism and clustering, and this algorithm can take advantage of both: use more cores and partition the input data accordingly. There's nothing wrong with `flatMap`; the time complexity prize goes to the function inside it.
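One way to act on that, offered as a swapped-in technique rather than something benchmarked above: instead of testing all `n` candidates against each transaction, enumerate the transaction's own subsets of each candidate size with `Set.subsets(k)` and look them up in a hash set of candidates. Per transaction this costs roughly the sum of C(|t|, k) lookups instead of `n` subset tests, which can pay off when transactions are short and `cand` is large, and gets worse as transactions grow. `freqBySubsets` is a hypothetical name, and the sketch assumes the candidates are distinct.

```scala
// Baseline, as in the question: test all n candidates against the transaction
def freq(trans: Set[String], cand: Seq[Set[String]]): Seq[(Set[String], Int)] =
  cand.collect { case c if c.subsetOf(trans) => (c, 1) }

// Hypothetical alternative: enumerate the transaction's subsets of each candidate size
// and keep those present in the candidate hash set (assumes distinct candidates)
def freqBySubsets(trans: Set[String], candSet: Set[Set[String]], candSizes: Set[Int]): Seq[(Set[String], Int)] =
  candSizes.toSeq.flatMap { k =>
    trans.subsets(k).filter(candSet.contains).map(c => (c, 1))
  }

val cand = Seq(Set("a"), Set("a", "c"), Set("b", "d"), Set("c"), Set("a", "b", "c"))
val candSet = cand.toSet               // built once, then reused for every transaction
val candSizes = cand.map(_.size).toSet // candidate sizes to enumerate
val trans = Set("a", "b", "c")

val viaScan = freq(trans, cand).toSet
val viaSubsets = freqBySubsets(trans, candSet, candSizes).toSet
println(viaSubsets == viaScan)
```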