MetallicPriest - 6 months ago
Python Question

How does the pyspark mapPartitions function work?

So I am trying to learn Spark using Python (PySpark). I want to know how the function mapPartitions works: what input it takes and what output it gives. I couldn't find any proper example on the internet. Let's say I have an RDD object containing lists, such as below.

[ [1, 2, 3], [3, 2, 4], [5, 2, 7] ]

And I want to remove element 2 from all the lists. How would I achieve that using mapPartitions?


mapPartitions should be thought of as a map operation over partitions, not over the elements of a partition. Its input is the set of current partitions; its output is another set of partitions.

The function you pass to map must take an individual element of your RDD and return a single new element.

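The function you pass to mapPartitions must take an iterator over the elements of one partition and return an iterable of elements of the same or some other type. It is called once per partition, and it may return more or fewer elements than it received.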
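To make the difference concrete without a cluster, here is a pure-Python model, not Spark itself. It assumes an RDD can be pictured as a list of partitions, each a list of elements; model_map, model_map_partitions and collect are hypothetical helpers written for illustration, not PySpark APIs.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def model_map(parts: List[List[T]], f: Callable[[T], U]) -> List[List[U]]:
    # map: f sees ONE element at a time and returns ONE element,
    # so each output partition has the same length as its input partition.
    return [[f(x) for x in part] for part in parts]


def model_map_partitions(
    parts: List[List[T]], f: Callable[[Iterator[T]], Iterable[U]]
) -> List[List[U]]:
    # mapPartitions: f is called once per partition with an iterator over
    # that partition's elements, and may return any number of elements.
    return [list(f(iter(part))) for part in parts]


def collect(parts: List[List[T]]) -> List[T]:
    # Rough analogue of RDD.collect(): concatenate the partitions in order.
    return [x for part in parts for x in part]


parts = [[1, 2, 3], [4, 5]]  # two partitions holding five elements in total

print(model_map(parts, lambda x: x * 10))  # [[10, 20, 30], [40, 50]]
print(model_map_partitions(parts, lambda it: [sum(it)]))  # [[6], [9]]
print(collect(model_map_partitions(parts, lambda it: [sum(it)])))  # [6, 9]
```

The per-partition sum shows the key point: five input elements become two output elements, which map cannot do. With real Spark, the analogous call would be rdd.mapPartitions(f) on an RDD with two partitions, where the exact result depends on how Spark split the data.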

In your case you probably just want to use map, something like:

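def filterOut2(line):
    # Keep every value of this one list except 2
    return [x for x in line if x != 2]

# data is the RDD of lists from the question
filtered_lists = data.map(filterOut2)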

If you wanted to use mapPartitions instead, it would be:

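def filterOut2FromPartition(list_of_lists):
    # list_of_lists is an iterator over the lists stored in ONE partition
    final_iterator = []
    for sub_list in list_of_lists:
        final_iterator.append([x for x in sub_list if x != 2])
    # mapPartitions expects an iterable back: one filtered list per input list
    return iter(final_iterator)

# Note the method name is mapPartitions (plural)
filtered_lists = data.mapPartitions(filterOut2FromPartition)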
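A common alternative is to write the partition function as a generator, since a generator is an iterable and so satisfies what mapPartitions expects. The sketch below uses a hypothetical name, filterOut2FromPartitionLazy, and simulates a single partition with a plain Python iterator rather than a real RDD.

```python
def filterOut2FromPartitionLazy(list_of_lists):
    # Generator version: yields one filtered list per input list, lazily,
    # instead of building the whole partition's result in memory first.
    for sub_list in list_of_lists:
        yield [x for x in sub_list if x != 2]


# Simulate one partition by passing an iterator, the way Spark hands it over.
partition = iter([[1, 2, 3], [3, 2, 4], [5, 2, 7]])
print(list(filterOut2FromPartitionLazy(partition)))  # [[1, 3], [3, 4], [5, 7]]
```

With Spark this would be used as data.mapPartitions(filterOut2FromPartitionLazy). Because each list is processed independently here, the result matches data.map(filterOut2); mapPartitions tends to pay off when some work, such as an expensive setup step, can be done once per partition instead of once per element.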