accssharma - 1 month ago
Scala Question

Divide elements of column by a sum of elements (of same column) grouped by elements of another column

I have been working on a Spark application and am trying to transform a DataFrame as shown in table 1. I want to divide each element of column _2 by the sum of the elements of that same column, grouped by the values of column _1. Table 2 is the expected result.

table 1

+---+---+
| _1| _2|
+---+---+
|  0| 13|
|  0|  7|
|  0|  3|
|  0|  1|
|  0|  1|
|  1|  4|
|  1|  8|
|  1| 18|
|  1|  4|
+---+---+


table 2

+---+----+
| _1|  _2|
+---+----+
|  0|13/x|
|  0| 7/x|
|  0| 3/x|
|  0| 1/x|
|  0| 1/x|
|  1| 4/y|
|  1| 8/y|
|  1|18/y|
|  1| 4/y|
+---+----+


where x = (13+7+3+1+1) and y = (4+8+18+4).

Then I want to calculate the entropy for each value in column _1, i.e. for each value in column _1, calculate -sum(p_i * log(p_i)) over column _2, where the p_i are the values in column _2 of table 2 that belong to that value of column _1.

The final output would be:

+---+---------+
| _1|  ENTROPY|
+---+---------+
|  0|entropy_1|
|  1|entropy_2|
+---+---------+


How can I implement this in Spark (preferably in Scala)? What would be an efficient way to perform the above operations? I'm new to Scala, so any related suggestions would be highly appreciated.

Thank you.

Answer

If you want a concise solution and the groups are reasonably small, you can use window functions. First, define a window that spans every row sharing the same value of _1:

import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy("_1").rowsBetween(Long.MinValue, Long.MaxValue)

Then compute the probability, i.e. each value of _2 divided by the sum of _2 over its window:

import org.apache.spark.sql.functions.sum

val p = $"_2" / sum($"_2").over(w)
val withP = df.withColumn("p", p)

and finally the entropy:

import org.apache.spark.sql.functions.log2

withP.groupBy($"_1").agg((-sum($"p" * log2($"p"))).alias("entropy"))
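For reference, the quantity these two steps compute can be written out as follows. The symbols v_i (a value of _2), g (the set of rows sharing one value of _1) and H are notation introduced here for illustration, not names from the code. Because the code negates the sum and uses log2, the result is the Shannon entropy measured in bits.

```latex
p_i = \frac{v_i}{\sum_{j \in g} v_j},
\qquad
H(g) = -\sum_{i \in g} p_i \log_2 p_i
```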

For the example data (toDF and the $ syntax need spark.implicits._ in scope, which spark-shell imports automatically)

val df = Seq(
  (0, 13), (0, 7), (0, 3), (0, 1), (0, 1), (1, 4), (1, 8), (1, 18), (1, 4)).toDF

the result is:

+---+------------------+
| _1|           entropy|
+---+------------------+
|  1|1.7033848993102918|
|  0|1.7433726580786888|
+---+------------------+
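As a sanity check, the same numbers can be reproduced with plain Scala collections, without Spark. This is only a minimal sketch: the object name EntropyCheck and the helper entropy are hypothetical names made up for this illustration, and the code assumes every count is strictly positive (a zero count would yield NaN here).

```scala
object EntropyCheck {
  // Shannon entropy in bits of one group of strictly positive counts.
  def entropy(counts: Seq[Int]): Double = {
    val total = counts.sum.toDouble
    counts.map { c =>
      val p = c / total                  // probability of this row within its group
      -p * (math.log(p) / math.log(2))   // -p * log2(p)
    }.sum
  }

  // The same (_1, _2) pairs as the example DataFrame.
  val data: Seq[(Int, Int)] = Seq(
    (0, 13), (0, 7), (0, 3), (0, 1), (0, 1), (1, 4), (1, 8), (1, 18), (1, 4))

  // Group by the key (_1) and compute one entropy value per group.
  val byKey: Map[Int, Double] =
    data.groupBy(_._1).map { case (k, rows) => k -> entropy(rows.map(_._2)) }

  def main(args: Array[String]): Unit =
    byKey.toSeq.sortBy(_._1).foreach { case (k, h) => println(s"$k -> $h") }
}
```

The two values in byKey should agree with the table above up to floating-point rounding, since the order of summation may differ from Spark's.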

If window functions are not acceptable performance-wise, you can aggregate the per-group totals first and join them back to the original data:

df.groupBy($"_1").agg(sum("_2").alias("total"))
  .join(df, Seq("_1"), "inner")
  .withColumn("p", $"_2" / $"total")
  .groupBy($"_1")
  .agg((-sum($"p" * log2($"p"))).alias("entropy"))
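If the join is also a concern, the entropy can in principle be rearranged so that it depends only on per-group aggregates. The derivation below is a sketch in notation introduced here (v_i for a value of _2, T for the group total), and it assumes every v_i is strictly positive. Under that assumption, a single groupBy with the same sum and log2 functions might be enough, though that variant is not part of the answer above and has not been benchmarked here.

```latex
\begin{aligned}
H(g) &= -\sum_{i \in g} \frac{v_i}{T} \log_2 \frac{v_i}{T},
\qquad T = \sum_{j \in g} v_j \\
     &= -\frac{1}{T} \sum_{i \in g} v_i \left( \log_2 v_i - \log_2 T \right) \\
     &= \log_2 T - \frac{1}{T} \sum_{i \in g} v_i \log_2 v_i
\end{aligned}
```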