
PySpark: Custom window function

I am currently trying to extract series of consecutive occurrences in a PySpark dataframe and order/rank them as shown below (for convenience, I have ordered the initial dataframe by user_id and timestamp):

df_ini
+-------+----------+--------+
|user_id| timestamp| actions|
+-------+----------+--------+
| 217498| 100000001|     'A'|
| 217498| 100000025|     'A'|
| 217498| 100000124|     'A'|
| 217498| 100000152|     'B'|
| 217498| 100000165|     'C'|
| 217498| 100000177|     'C'|
| 217498| 100000182|     'A'|
| 217498| 100000197|     'B'|
| 217498| 100000210|     'B'|
| 854123| 100000005|     'A'|
| 854123| 100000007|     'A'|
| etc.


to:

expected df_transformed
+-------+--------+----------+------+
|user_id| actions| nb_of_occ| order|
+-------+--------+----------+------+
| 217498|     'A'|         3|     1|
| 217498|     'B'|         1|     2|
| 217498|     'C'|         2|     3|
| 217498|     'A'|         1|     4|
| 217498|     'B'|         2|     5|
| 854123|     'A'|         2|     1|
| etc.


My guess is that I have to use a smart window function that partitions the table by user_id and actions, but only when these actions are consecutive in time, which I can't figure out how to do...

If someone has encountered this type of transformation in PySpark before, I'd be glad to get a hint!

Cheers

Answer

This is a pretty common pattern and can be expressed using window functions in a few steps. First, import the required functions:

from pyspark.sql.functions import sum as sum_, lag, col, coalesce, lit
from pyspark.sql.window import Window

Next define a window:

# One window per user, with rows ordered chronologically
w = Window.partitionBy("user_id").orderBy("timestamp")

Mark first row for each group:

# Flag rows where the action differs from the previous action for the same user;
# lag returns null on each user's first row, so coalesce marks those rows as 1 too
is_first = coalesce(
    (lag("actions", 1).over(w) != col("actions")).cast("bigint"),
    lit(1)
)
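
As a quick, optional sanity check, you can materialize this flag before aggregating (using the sample df defined further down); each user's first row and every row where the action changes should carry a 1:

# Optional inspection step: show the flag per user in chronological order
df.withColumn("is_first", is_first).orderBy("user_id", "timestamp").show()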

Define order:

# Running sum of the flags: it increments at each change of action, giving a run index per user
order = sum_("is_first").over(w)

And combine all the parts with an aggregation:

(df
    .withColumn("is_first", is_first)
    .withColumn("order", order)
    .groupBy("user_id", "actions", "order")
    .count())
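
If you also want the column names and row order from the expected output above, you can optionally rename count to nb_of_occ and sort; this is just a finishing touch on top of the same pipeline:

(df
    .withColumn("is_first", is_first)
    .withColumn("order", order)
    .groupBy("user_id", "actions", "order")
    .count()
    .withColumnRenamed("count", "nb_of_occ")  # match the expected column name
    .orderBy("user_id", "order")
    .show())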

If you define df as:

df = sc.parallelize([
    (217498, 100000001, 'A'), (217498, 100000025, 'A'), (217498, 100000124, 'A'),
    (217498, 100000152, 'B'), (217498, 100000165, 'C'), (217498, 100000177, 'C'),
    (217498, 100000182, 'A'), (217498, 100000197, 'B'), (217498, 100000210, 'B'),
    (854123, 100000005, 'A'), (854123, 100000007, 'A')
]).toDF(["user_id", "timestamp", "actions"])

and order the result by user_id and order, you'll get:

+-------+-------+-----+-----+ 
|user_id|actions|order|count|
+-------+-------+-----+-----+
| 217498|      A|    1|    3|
| 217498|      B|    2|    1|
| 217498|      C|    3|    2|
| 217498|      A|    4|    1|
| 217498|      B|    5|    2|
| 854123|      A|    1|    2|
+-------+-------+-----+-----+
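
As a side note, the sc.parallelize(...).toDF(...) construction assumes an RDD-style entry point (sc) is available; on Spark 2.x and later you would typically build the same DataFrame through the SparkSession instead, for example:

# Equivalent DataFrame construction via a SparkSession named `spark`
df = spark.createDataFrame([
    (217498, 100000001, 'A'), (217498, 100000025, 'A'), (217498, 100000124, 'A'),
    (217498, 100000152, 'B'), (217498, 100000165, 'C'), (217498, 100000177, 'C'),
    (217498, 100000182, 'A'), (217498, 100000197, 'B'), (217498, 100000210, 'B'),
    (854123, 100000005, 'A'), (854123, 100000007, 'A')
], ["user_id", "timestamp", "actions"])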