Jason - 6 months ago
Python Question

Reshaping/Pivoting data in Spark RDD and/or Spark DataFrames

I have some data in the following format (either RDD or Spark DataFrame):

from pyspark.sql import SQLContext
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
sqlContext = SQLContext(sc)

rdd = sc.parallelize([('X01',41,'US',3),
('X01',41,'UK',1),
('X01',41,'CA',2),
('X02',72,'US',4),
('X02',72,'UK',6),
('X02',72,'CA',7),
('X02',72,'XX',8)])

# convert to a Spark DataFrame
schema = StructType([StructField('ID', StringType(), True),
StructField('Age', IntegerType(), True),
StructField('Country', StringType(), True),
StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)


What I would like to do is 'reshape' the data, converting certain values of Country (specifically US, UK and CA) into columns:

ID Age US UK CA
'X01' 41 3 1 2
'X02' 72 4 6 7


Essentially, I need something along the lines of the pivot workflow in Python (pandas):

categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index = 'ID',
columns = 'Country',
values = 'Score')


My dataset is rather large, so I can't really collect() it and ingest the data into memory to do the reshaping in Python itself. Is there a way to turn Python's .pivot() into a function that can be invoked while mapping over either an RDD or a Spark DataFrame? Any help would be appreciated!

Answer

First up, this is probably not a good idea: you do not gain any extra information, but you tie yourself to a fixed schema (i.e. you must know in advance which countries to expect, and every additional country means a code change).

Having said that, this is really a SQL problem, and the SQL answer is given as Solution 2. If you feel SQL is not "software like" enough (seriously, I have heard this!), you can use the first solution, which works on the RDD directly.

Solution 1:

def reshape(t):
    # Key on (ID, Age); put the score in the slot of its country, 0 elsewhere.
    scores = tuple(t[3] if t[2] == v else 0 for v in brc.value)
    return (t[0], t[1]), scores

def cntryFilter(t):
    # Keep only rows whose country is in the broadcast list.
    return t[2] in brc.value

def addtup(t1,t2):
    j=()
    for k,v in enumerate(t1):
        j=j+(t1[k]+t2[k],)
    return j

def seq(tIntrm,tNext):
    return addtup(tIntrm,tNext)

def comb(tP,tF):
    return addtup(tP,tF)


countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)
reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1)
for i in pivot.collect():
    print(i)
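
To see what the filter, map and aggregateByKey steps do without a cluster, the same fold can be traced in plain Python. This is only a local sketch: reshape_local and pivot_local are hypothetical names that are not part of the answer, and a dictionary stands in for Spark's per-key aggregation.

```python
from collections import defaultdict

# Sample rows from the "data set up" section: (ID, Age, Country, Score).
data = [('X01', 41, 'US', 3), ('X01', 41, 'UK', 1), ('X01', 41, 'CA', 2),
        ('X02', 72, 'US', 4), ('X02', 72, 'UK', 6), ('X02', 72, 'CA', 7),
        ('X02', 72, 'XX', 8)]
countries = ['CA', 'UK', 'US', 'XX']


def reshape_local(t, countries):
    """Return ((ID, Age), scores) with the score in its country's slot."""
    scores = tuple(t[3] if t[2] == c else 0 for c in countries)
    return (t[0], t[1]), scores


def addtup(t1, t2):
    """Element-wise sum of two equal-length tuples."""
    return tuple(a + b for a, b in zip(t1, t2))


def pivot_local(rows, countries):
    """Plain-Python stand-in (an assumption) for filter + map + aggregateByKey."""
    zero = (0,) * len(countries)
    acc = defaultdict(lambda: zero)
    for row in rows:
        if row[2] not in countries:
            continue  # same role as cntryFilter
        key, scores = reshape_local(row, countries)
        acc[key] = addtup(acc[key], scores)  # same role as seq/comb
    return dict(acc)


result = pivot_local(data, countries)
for key in sorted(result):
    print(key, result[key])
```

Each row contributes a tuple that is zero everywhere except in its own country's position, so summing the tuples per key lines the scores up into columns.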

Solution 2: the better option, since SQL is the right tool for this.

from pyspark.sql import Row

# ssc is assumed to be a SQLContext, e.g. ssc = SQLContext(sc)
callRow = calls.map(lambda t: Row(userid=t[0], age=int(t[1]), country=t[2], nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)
callsDF.printSchema()
callsDF.registerTempTable("calls")
res = ssc.sql("select userid,age,max(ca) as ca,max(uk) as uk,max(us) as us,max(xx) as xx\
                    from (select userid,age,\
                                  case when country='CA' then nbrCalls else 0 end ca,\
                                  case when country='UK' then nbrCalls else 0 end uk,\
                                  case when country='US' then nbrCalls else 0 end us,\
                                  case when country='XX' then nbrCalls else 0 end xx \
                             from calls) x \
                     group by userid,age")
res.show()
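
The query hard-codes one case expression per country, which is the fixed-schema cost mentioned at the top of the answer. One way to soften that is to generate the statement from the country list. The sketch below makes several assumptions: build_pivot_sql is a hypothetical helper, the nested subquery is collapsed into a single level of conditional aggregation, and SQLite is used purely as a local stand-in engine to check the generated statement (it is not Spark).

```python
import sqlite3


def build_pivot_sql(countries, table="calls"):
    """Build a conditional-aggregation query with one column per country.

    The codes are embedded as SQL literals, so `countries` should be a
    trusted list of simple codes, not user input.
    """
    cols = ",\n       ".join(
        "max(case when country = '{0}' then nbrCalls else 0 end) as {1}"
        .format(c, c.lower())
        for c in countries
    )
    return ("select userid, age,\n       " + cols +
            "\nfrom " + table + "\ngroup by userid, age")


data = [('X01', 41, 'US', 3), ('X01', 41, 'UK', 1), ('X01', 41, 'CA', 2),
        ('X02', 72, 'US', 4), ('X02', 72, 'UK', 6), ('X02', 72, 'CA', 7),
        ('X02', 72, 'XX', 8)]

sql = build_pivot_sql(['CA', 'UK', 'US', 'XX'])
print(sql)

# Local check against an in-memory SQLite table (stand-in for the Spark temp table).
conn = sqlite3.connect(":memory:")
conn.execute("create table calls (userid text, age integer, country text, nbrCalls integer)")
conn.executemany("insert into calls values (?, ?, ?, ?)", data)
rows = conn.execute(sql + " order by userid").fetchall()
conn.close()
print(rows)
```

In the Spark setting, the generated string would presumably be passed to ssc.sql(...) in place of the hand-written one; adding a country then means changing the list, not the query text.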

data set up:

data = [('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),
        ('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)]
calls = sc.parallelize(data,1)
countries = ['CA', 'UK', 'US', 'XX']

Result:

From 1st solution

(('X02', 72), (7, 6, 4, 8)) 
(('X01', 41), (2, 1, 3, 0))

From 2nd solution:

root
 |-- age: long (nullable = true)
 |-- country: string (nullable = true)
 |-- nbrCalls: long (nullable = true)
 |-- userid: string (nullable = true)

userid age ca uk us xx 
 X02    72  7  6  4  8  
 X01    41  2  1  3  0

Kindly let me know if this works, or not :)

Best Ayan
