solistice - 1 year ago
Scala Question

Apache Spark flatMap time complexity

I've been trying to find a way to count the number of times sets of Strings occur in a transaction database (implementing the Apriori algorithm in a distributed fashion). The code I have currently is as follows:

import scala.collection.mutable.ArrayBuffer

def freq(trans: Set[String], cand: Array[Set[String]]): Array[(Set[String], Int)] = {
  val res = ArrayBuffer[(Set[String], Int)]()
  for (c <- cand) {
    if (c.subsetOf(trans)) {
      res += ((c, 1))
    }
  }
  res.toArray
}

val cand_br = sc.broadcast(cand)
transactions.flatMap(trans => freq(trans, cand_br.value))
  .reduceByKey(_ + _)

transactions starts out as an RDD[Set[String]], and I'm trying to convert it to an RDD[(K, V)], with K every element in cand and V the number of occurrences of each element of cand in the transaction list.

When watching performance in the Spark UI, the flatMap stage takes about 3 minutes to finish, whereas the rest takes < 1 ms.

For an idea of the data I'm dealing with:

transactions.count() ~= 88000
cand.length ~= 24000

I've tried different ways of persisting the data, but I'm fairly sure the problem I am facing is algorithmic.

Is there a better way to solve this subproblem?

PS: I'm fairly new to Scala and the Spark framework, so there might be some strange constructions in this code.
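Regarding the PS: the loop in freq can be written more idiomatically as a filter followed by a map. The sketch below is plain Scala with no Spark dependency; freqIdiomatic and the toy candidates are hypothetical names chosen for illustration. It performs the same subset checks as the original, so it tidies the code without changing its running time.

```scala
// Hypothetical idiomatic rewrite of freq: keep the candidates that are
// subsets of the transaction and pair each one with a count of 1.
def freqIdiomatic(trans: Set[String], cand: Array[Set[String]]): Array[(Set[String], Int)] =
  cand.filter(_.subsetOf(trans)).map(c => (c, 1))

// Toy data, for illustration only.
val cand: Array[Set[String]] = Array(Set("a"), Set("a", "b"), Set("c"))
val pairs = freqIdiomatic(Set("a", "b"), cand)
```

The result type is still Array[(Set[String], Int)], so it could replace freq inside the flatMap unchanged.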

Answer

Probably, the right question to ask in this case would be: "what is the time complexity of this algorithm?" I think it has very little to do with Spark's flatMap operation.

Rough O-complexity analysis

Given two collections of Sets, of sizes m and n, this algorithm counts how many elements of one collection are subsets of elements of the other, so at first sight the complexity is m x n. Looking one level deeper, we see that 'subsetOf' is linear in the number of elements of the subset (assuming constant-time membership lookups, as with hash sets): x subsetOf y == x forall y. So the complexity is actually m x n x s, where s is the cardinality of the subsets being checked.
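To make m x n x s concrete, the sketch below counts the membership lookups a naive subset check performs on toy data. countingSubsetOf is a hypothetical helper (not a library method) that behaves like x.subsetOf(y), implemented as x.forall(y.contains) with a counter; the block also computes m x n for the sizes quoted in the question.

```scala
// Hypothetical helper: same result as x.subsetOf(y), implemented as
// x.forall(y.contains), while counting the membership lookups it makes.
def countingSubsetOf(x: Set[String], y: Set[String], counter: Array[Long]): Boolean =
  x.forall { e => counter(0) += 1; y.contains(e) }

// Toy data: m = 3 transactions, n = 3 candidates, s <= 2.
val transactions = Seq(Set("a", "b", "c"), Set("a", "c"), Set("b"))
val cand = Seq(Set("a"), Set("a", "c"), Set("b", "c"))

val lookups = Array(0L)
// Every (transaction, candidate) pair is checked: m x n subset checks.
val matches = for {
  t <- transactions
  c <- cand
  if countingSubsetOf(c, t, lookups)
} yield c

// Number of subset checks for the sizes in the question: m x n.
val checksInQuestion = 88000L * 24000L
```

On the toy data the 3 x 3 pairs need 13 lookups, below the worst-case bound of 3 x 3 x 2 = 18 because forall stops at the first missing element. For the question's sizes, m x n alone is about 2.1 billion subset checks.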

In other words, this flatMap operation has a lot of work to do.

Going Parallel

Now, going back to Spark, we can also observe that this algo is embarrassingly parallel, and we can use Spark's capabilities to our advantage.
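"Embarrassingly parallel" here means each transaction can be processed on its own and partial counts can be merged afterwards, which is what flatMap followed by reduceByKey relies on. The plain-Scala sketch below illustrates this by splitting the transactions into chunks, counting each chunk independently, and merging the partial maps. countChunk and merge are hypothetical helpers; the merge step is only an analogy for Spark's reduceByKey, not its implementation.

```scala
type Counts = Map[Set[String], Int]

// Count candidate occurrences within one chunk of transactions.
def countChunk(chunk: Seq[Set[String]], cand: Seq[Set[String]]): Counts =
  chunk.flatMap(t => cand.filter(_.subsetOf(t)))
    .groupBy(identity)
    .map { case (c, occurrences) => c -> occurrences.size }

// Merge two partial results by summing the counts per key.
def merge(a: Counts, b: Counts): Counts =
  b.foldLeft(a) { case (acc, (k, v)) => acc.updated(k, acc.getOrElse(k, 0) + v) }

// Toy data, for illustration only.
val transactions = Seq(Set("a", "b"), Set("a"), Set("a", "b", "c"), Set("b", "c"))
val cand = Seq(Set("a"), Set("b"), Set("a", "b"))

// Counting everything at once vs. counting chunks of 2 and merging.
val whole = countChunk(transactions, cand)
val byChunks = transactions.grouped(2)
  .map(countChunk(_, cand))
  .foldLeft(Map.empty[Set[String], Int])(merge)
```

Because the merged chunk counts equal the whole-dataset counts, the transactions can be split across workers in any way, which is what lets Spark spread the flatMap over partitions.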

To compare some approaches, I loaded the 'retail' dataset [1] and ran the algo with val cand = transactions.filter(_.size < 4).collect. The data size is close to the one in the question:

  • transactions.count = 88162
  • cand.size = 15451

Some comparative runs on local mode:

  • Vanilla: 1.2 minutes
  • Increasing the number of transactions partitions up to the number of cores (8): 33 secs
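Spark runs one task per partition, so with fewer partitions than cores some cores sit idle; in Spark, the partition count can be raised with, for example, transactions.repartition(n). The sketch below imitates that effect without Spark, using scala.concurrent.Future as a stand-in for Spark tasks: it splits the transactions into at most one chunk per available core and counts subset matches in each chunk concurrently. countMatchesParallel and the toy data are hypothetical; this is an analogy, not how Spark schedules work.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

// Hypothetical helper: count (transaction, candidate) subset matches,
// running one Future per chunk, analogous to one Spark task per partition.
def countMatchesParallel(
    transactions: Vector[Set[String]],
    cand: Vector[Set[String]],
    numPartitions: Int): Long = {
  // Ceiling division, so at most numPartitions chunks are created.
  val chunkSize = math.max(1, (transactions.size + numPartitions - 1) / numPartitions)
  val tasks = transactions.grouped(chunkSize).toVector.map { chunk =>
    Future(chunk.map(t => cand.count(_.subsetOf(t)).toLong).sum)
  }
  // Wait for all chunks and add up their partial counts.
  Await.result(Future.sequence(tasks), Duration.Inf).sum
}

val cores = Runtime.getRuntime.availableProcessors()
// Toy data: every transaction matches 2 of the 3 candidates.
val transactions = Vector.fill(1000)(Set("a", "b", "c"))
val cand = Vector(Set("a"), Set("b", "c"), Set("d"))
val total = countMatchesParallel(transactions, cand, cores)
```

The total does not depend on the number of chunks; only how the work is spread over the cores changes.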

I also tried an alternative implementation, using cartesian instead of flatMap:

 // candRDD: the candidates as an RDD (its definition is not shown; e.g. sc.parallelize(cand))
 transactions.cartesian(candRDD).map { case (tx, cd) => (cd, if (cd.subsetOf(tx)) 1 else 0) }.reduceByKey(_ + _).collect

But that resulted in much longer runs, as seen in the top two lines of the Spark UI (cartesian, and cartesian with a higher number of partitions): 2.5 min.

Given I only have 8 logical cores available, going above that does not help.

[Screenshot: Spark UI]
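One plausible reason the cartesian variant is slower (an interpretation, not something measured above): it builds a record for every (transaction, candidate) pair, m x n in total, including all the zeros, before reduceByKey sees them, whereas the flatMap version only emits a pair when the candidate actually matches. The plain-Scala sketch below counts the records each approach would produce on toy data; sumByKey and the data are hypothetical.

```scala
// Toy data, for illustration only: m = 3 transactions, n = 4 candidates.
val transactions = Seq(Set("a", "b"), Set("a"), Set("b", "c"))
val cand = Seq(Set("a"), Set("b"), Set("a", "b"), Set("d"))

// cartesian-style: one (candidate, 0 or 1) record per pair, m x n in total.
val cartesianRecords = for {
  tx <- transactions
  cd <- cand
} yield (cd, if (cd.subsetOf(tx)) 1 else 0)

// flatMap-style: a (candidate, 1) record only for the pairs that match.
val flatMapRecords =
  transactions.flatMap(tx => cand.filter(_.subsetOf(tx)).map(cd => (cd, 1)))

// Sum values per key, as reduceByKey(_ + _) would.
def sumByKey(records: Seq[(Set[String], Int)]): Map[Set[String], Int] =
  records.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).sum }
```

Both give the same non-zero totals, but here the cartesian version produces 12 records against 5 for the flatMap version, and the gap grows with the number of non-matching pairs.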

Sanity checks:

Is there any added 'Spark flatMap time complexity'? Probably some, as it involves serializing closures and unpacking collections, but it is negligible compared with the function being executed.

Let's see if we can do a better job: I implemented the same algo in plain Scala:

val resLocal = reduceByKey(transLocal.flatMap(trans => freq(trans, cand)))

where the reduceByKey operation is a naive implementation taken from [2]. Execution time: 3.67 seconds. Spark gives you parallelism out of the box; this implementation is totally sequential and therefore takes longer to complete.
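The implementation from [2] is not reproduced here. As a stand-in, a naive local reduceByKey could look like the sketch below; it is a hypothetical helper, neither Spark's API nor necessarily the version from [2]. Like the one-argument call above, it takes only the pairs and sums their Int values per key, folding them into an immutable Map in a single pass on one thread.

```scala
// Hypothetical local stand-in for Spark's reduceByKey(_ + _): a single
// sequential pass that sums the Int values of pairs sharing the same key.
def reduceByKey[K](pairs: Iterable[(K, Int)]): Map[K, Int] =
  pairs.foldLeft(Map.empty[K, Int]) { case (acc, (k, v)) =>
    acc.updated(k, acc.getOrElse(k, 0) + v)
  }

// Toy input shaped like the output of freq: (candidate, 1) pairs.
val counts = reduceByKey(Seq((Set("a"), 1), (Set("a", "b"), 1), (Set("a"), 1)))
```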

Last sanity check: a trivial flatMap operation:

transactions.flatMap(trans => Seq((trans, 1))).reduceByKey(_ + _).collect

Execution time: 0.88 secs


Spark buys you parallelism and clustering, and this algo can take advantage of both. Use more cores and partition the input data accordingly. There's nothing wrong with flatMap; the time-complexity prize goes to the function inside it.