alexander ostrikov - 5 months ago
Python Question

Spark iteration time increasing exponentially when using join

I'm quite new to Spark and I'm trying to implement an iterative clustering algorithm (expectation-maximization) with centroids represented by a Markov model. So I need to do iterations and joins.

One problem I'm experiencing is that the time of each iteration grows exponentially.

After some experimenting I found that, when iterating, you need to persist the RDD that is going to be reused in the next iteration; otherwise Spark will create an execution plan in every iteration that recalculates the RDD from the start, thus increasing calculation time.

import datetime

init = sc.parallelize(xrange(10000000), 3)
init.cache()

for i in range(6):
    print i
    start = datetime.datetime.now()

    init2 = init.map(lambda n: (n, n*3))
    init = init2.map(lambda n: n[0])
    # init.cache()

    print init.count()
    print str(datetime.datetime.now() - start)


Results in:

0
10000000
0:00:04.283652
1
10000000
0:00:05.998830
2
10000000
0:00:08.771984
3
10000000
0:00:11.399581
4
10000000
0:00:14.206069
5
10000000
0:00:16.856993


So adding cache() helps, and the iteration time becomes constant.

init = sc.parallelize(xrange(10000000), 3)
init.cache()

for i in range(6):
    print i
    start = datetime.datetime.now()

    init2 = init.map(lambda n: (n, n*3))
    init = init2.map(lambda n: n[0])
    init.cache()

    print init.count()
    print str(datetime.datetime.now() - start)

Results in:

0
10000000
0:00:04.966835
1
10000000
0:00:04.609885
2
10000000
0:00:04.324358
3
10000000
0:00:04.248709
4
10000000
0:00:04.218724
5
10000000
0:00:04.223368
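The effect of caching on a growing lineage can be mimicked without Spark. The following is only a toy sketch: `LazySeq` and `work_per_iteration` are hypothetical names, not part of any Spark API, and the class merely counts how many times a mapped function is applied when the chain is recomputed from the start versus when the previous result is kept.

```python
class LazySeq(object):
    """Toy stand-in for an RDD: remembers its parent and the function to apply."""
    calls = 0  # total number of element-level function applications

    def __init__(self, source=None, parent=None, func=None):
        self.source = source
        self.parent = parent
        self.func = func
        self.use_cache = False
        self.cached = None

    def map(self, func):
        # Lazy: nothing is computed here, only the lineage grows.
        return LazySeq(parent=self, func=func)

    def cache(self):
        self.use_cache = True
        return self

    def compute(self):
        if self.cached is not None:
            return self.cached
        if self.parent is None:
            data = list(self.source)
        else:
            data = []
            for x in self.parent.compute():
                LazySeq.calls += 1
                data.append(self.func(x))
        if self.use_cache:
            self.cached = data
        return data

    def count(self):
        return len(self.compute())


def work_per_iteration(use_cache, iterations=4, size=100):
    """Return the number of function applications needed in each iteration."""
    seq = LazySeq(source=range(size)).cache()
    work = []
    for _ in range(iterations):
        before = LazySeq.calls
        seq = seq.map(lambda n: (n, n * 3)).map(lambda p: p[0])
        if use_cache:
            seq.cache()
        seq.count()
        work.append(LazySeq.calls - before)
    return work


without_cache = work_per_iteration(use_cache=False)
with_cache = work_per_iteration(use_cache=True)
```

In this toy model the work per iteration grows by a fixed amount when nothing is cached and stays flat when each result is cached, which is the same shape as the two timing runs above.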


But when doing a join inside the iteration, the problem comes back.
Here is some simple code demonstrating the problem. Even caching each RDD transformation doesn't solve it:

init = sc.parallelize(xrange(10000), 3)
init.cache()

for i in range(6):
    print i
    start = datetime.datetime.now()

    init2 = init.map(lambda n: (n, n*3))
    init2.cache()

    init3 = init.map(lambda n: (n, n*2))
    init3.cache()

    init4 = init2.join(init3)
    init4.count()
    init4.cache()

    init = init4.map(lambda n: n[0])
    init.cache()

    print init.count()
    print str(datetime.datetime.now() - start)


And here is the output. As you can see, the iteration time grows exponentially :(

0
10000
0:00:00.674115
1
10000
0:00:00.833377
2
10000
0:00:01.525314
3
10000
0:00:04.194715
4
10000
0:00:08.139040
5
10000
0:00:17.852815


I will really appreciate any help :)

Answer

Summary:

Generally speaking, iterative algorithms, especially ones with a self-join or self-union, require control over several things, among them the number of partitions.

The problem described here results from a lack of control over the number of partitions. In each iteration the number of partitions increases with the self-join, which leads to the exponential pattern. To address that you have to either control the number of partitions in each iteration (see below) or use global tools like spark.default.parallelism (see the answer provided by Travis). The first approach generally provides much more control and doesn't affect other parts of the code.

Original answer:

As far as I can tell, there are two intertwined problems here: a growing number of partitions and shuffling overhead during joins. Both can be easily handled, so let's go step by step.

First, let's create a helper to collect the statistics:

import datetime

def get_stats(i, init, init2, init3, init4,
       start, end, desc, cache, part, hashp):
    # Note: the "init1" and "init2" entries hold the partition counts
    # of the RDDs init2 and init3, respectively.
    return {
        "i": i,
        "init": init.getNumPartitions(),
        "init1": init2.getNumPartitions(),
        "init2": init3.getNumPartitions(),
        "init4": init4.getNumPartitions(),
        "time": str(end - start),
        # Elapsed time as a float number of seconds
        "timen": (end - start).total_seconds(),
        "desc": desc,
        "cache": cache,
        "part": part,
        "hashp": hashp
    }

Another helper to handle caching and partitioning:

def procRDD(rdd, cache=True, part=False, hashp=False, npart=16):
    rdd = rdd if not part else rdd.repartition(npart)
    rdd = rdd if not hashp else rdd.partitionBy(npart)
    return rdd if not cache else rdd.cache()

Extract the pipeline logic:

def run(init, description, cache=True, part=False, hashp=False, 
    npart=16, n=6):
    times = []

    for i in range(n):
        start = datetime.datetime.now()

        init2 = procRDD(
                init.map(lambda n: (n, n*3)),
                cache, part, hashp, npart)
        init3 = procRDD(
                init.map(lambda n: (n, n*2)),
                cache, part, hashp, npart)


        # If part set to True limit number of the output partitions
        init4 = init2.join(init3, npart) if part else init2.join(init3) 
        init = init4.map(lambda n: n[0])

        if cache:
            init4.cache()
            init.cache()

        init.count() # Force computations to get time
        end = datetime.datetime.now() 

        times.append(get_stats(
            i, init, init2, init3, init4,
            start, end, description,
            cache, part, hashp
        ))

    return times

And create the initial data:

ncores = 8
init = sc.parallelize(xrange(10000), ncores * 2).cache()

If the numPartitions argument is not provided, the join operation by itself adjusts the number of partitions in the output based on the number of partitions of the input RDDs. This means the number of partitions grows with each iteration. If the number of partitions is too large, things get ugly. You can deal with this by providing the numPartitions argument to join or by repartitioning the RDDs in each iteration.

timesCachePart = sqlContext.createDataFrame(
        run(init, "cache + partition", True, True, False, ncores * 2))
timesCachePart.select("i", "init1", "init2", "init4", "time", "desc").show()

+-+-----+-----+-----+--------------+-----------------+
|i|init1|init2|init4|          time|             desc|
+-+-----+-----+-----+--------------+-----------------+
|0|   16|   16|   16|0:00:01.145625|cache + partition|
|1|   16|   16|   16|0:00:01.090468|cache + partition|
|2|   16|   16|   16|0:00:01.059316|cache + partition|
|3|   16|   16|   16|0:00:01.029544|cache + partition|
|4|   16|   16|   16|0:00:01.033493|cache + partition|
|5|   16|   16|   16|0:00:01.007598|cache + partition|
+-+-----+-----+-----+--------------+-----------------+

As you can see, when we repartition, the execution time is more or less constant. The second problem is that the data above is partitioned randomly. To ensure join performance, we would like to have the same keys on a single partition. To achieve that, we can use a hash partitioner:

timesCacheHashPart = sqlContext.createDataFrame(
    run(init, "cache + hashpart", True, True, True, ncores * 2))
timesCacheHashPart.select("i", "init1", "init2", "init4", "time", "desc").show()

+-+-----+-----+-----+--------------+----------------+
|i|init1|init2|init4|          time|            desc|
+-+-----+-----+-----+--------------+----------------+
|0|   16|   16|   16|0:00:00.946379|cache + hashpart|
|1|   16|   16|   16|0:00:00.966519|cache + hashpart|
|2|   16|   16|   16|0:00:00.945501|cache + hashpart|
|3|   16|   16|   16|0:00:00.986777|cache + hashpart|
|4|   16|   16|   16|0:00:00.960989|cache + hashpart|
|5|   16|   16|   16|0:00:01.026648|cache + hashpart|
+-+-----+-----+-----+--------------+----------------+

Execution time is constant as before, and there is a small improvement over the basic partitioning.
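To illustrate why putting the same keys on the same partition helps a join, here is a minimal pure-Python sketch. It is not Spark code: `hash_partition` and `partitionwise_join` are hypothetical helpers, and Python's built-in `hash` is used as a stand-in for whatever partitioning function Spark applies. The point is only that, once both sides are bucketed the same way, bucket i of one side needs to be compared with bucket i of the other side and nothing else.

```python
from collections import defaultdict


def hash_partition(pairs, num_partitions):
    """Place each (key, value) pair into the bucket hash(key) % num_partitions."""
    buckets = [[] for _ in range(num_partitions)]
    for key, value in pairs:
        buckets[hash(key) % num_partitions].append((key, value))
    return buckets


def partitionwise_join(left_buckets, right_buckets):
    """Join bucket i of the left side only with bucket i of the right side."""
    joined = []
    for left, right in zip(left_buckets, right_buckets):
        # Build a per-bucket lookup table for the right side.
        lookup = defaultdict(list)
        for key, value in right:
            lookup[key].append(value)
        for key, value in left:
            for other in lookup.get(key, []):
                joined.append((key, (value, other)))
    return joined


# Same shape as init2 and init3 in the answer, on a smaller range.
left = hash_partition([(n, n * 3) for n in range(100)], 16)
right = hash_partition([(n, n * 2) for n in range(100)], 16)
result = partitionwise_join(left, right)
```

Because both sides use the same function and the same number of buckets, no pair has to be moved between buckets before joining, which is the kind of data movement the hash partitioner is meant to avoid.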

Now let's use cache only as a reference:

timesCacheOnly = sqlContext.createDataFrame(
    run(init, "cache-only", True, False, False, ncores * 2))
timesCacheOnly.select("i", "init1", "init2", "init4", "time", "desc").show()


+-+-----+-----+-----+--------------+----------+
|i|init1|init2|init4|          time|      desc|
+-+-----+-----+-----+--------------+----------+
|0|   16|   16|   32|0:00:00.992865|cache-only|
|1|   32|   32|   64|0:00:01.766940|cache-only|
|2|   64|   64|  128|0:00:03.675924|cache-only|
|3|  128|  128|  256|0:00:06.477492|cache-only|
|4|  256|  256|  512|0:00:11.929242|cache-only|
|5|  512|  512| 1024|0:00:23.284508|cache-only|
+-+-----+-----+-----+--------------+----------+

As you can see, the number of partitions for the cache-only version (columns init1, init2 and init4, which report the RDDs init2, init3 and init4) doubles with each iteration, and the execution time is roughly proportional to the number of partitions.
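The doubling can be reproduced with simple arithmetic. The sketch below is a model, not Spark code, and `simulate_partition_growth` is a hypothetical name. It assumes that, without numPartitions, the join output has as many partitions as both inputs combined, and that both inputs of the next iteration inherit that count. This assumption is consistent with the init4 column in the cache-only table (16 + 16 = 32, 32 + 32 = 64, and so on).

```python
def simulate_partition_growth(initial_partitions, iterations):
    """Return the modelled partition count of the join output per iteration."""
    counts = []
    parts = initial_partitions
    for _ in range(iterations):
        # Assumption: output partitions = left partitions + right partitions,
        # and both sides of the self-join have the same partition count.
        parts = parts + parts
        counts.append(parts)
    return counts


growth = simulate_partition_growth(16, 6)
```

With 16 initial partitions and 6 iterations the model gives 32, 64, 128, 256, 512 and 1024, the same values as the init4 column.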

Finally, we can check whether a hash partitioner improves performance when the number of partitions is large:

timesCacheHashPart512 = sqlContext.createDataFrame(
    run(init, "cache + hashpart 512", True, True, True, 512))
timesCacheHashPart512.select(
    "i", "init1", "init2", "init4", "time", "desc").show()

+-+-----+-----+-----+--------------+--------------------+
|i|init1|init2|init4|          time|                desc|
+-+-----+-----+-----+--------------+--------------------+
|0|  512|  512|  512|0:00:14.492690|cache + hashpart 512|
|1|  512|  512|  512|0:00:20.215408|cache + hashpart 512|
|2|  512|  512|  512|0:00:20.408070|cache + hashpart 512|
|3|  512|  512|  512|0:00:20.390267|cache + hashpart 512|
|4|  512|  512|  512|0:00:20.362354|cache + hashpart 512|
|5|  512|  512|  512|0:00:19.878525|cache + hashpart 512|
+-+-----+-----+-----+--------------+--------------------+

The improvement is not so impressive, but if you have a small cluster and a lot of data, it is still worth trying.

I guess the take-away message here is that partitioning matters. There are contexts where it is handled for you (mllib, sql), but if you use low-level operations it is your responsibility.