sushant-hiray - 2 months ago
Scala Question

Spark: Efficient way to get top K frequent values per key in (key, value) RDD?

I have an RDD of (key, value) pairs. For each key, I need to fetch the top k values by frequency.

I understand that the best way to do this would be using combineByKey.

Currently, here is what my combineByKey combiners look like:

object TopKCount {
  //TopK Count combiners
  val k: Int = 10

  def createCombiner(value: String): Map[String, Long] =
    Map(value -> 1L)

  def mergeValue(combined: Map[String, Long], value: String): Map[String, Long] =
    combined ++ Map(value -> (combined.getOrElse(value, 0L) + 1L))

  def mergeCombiners(combined1: Map[String, Long], combined2: Map[String, Long]): Map[String, Long] = {
    val top10Keys1 = combined1.toList.sortBy(_._2).takeRight(k).toMap.keys
    val top10Keys2 = combined2.toList.sortBy(_._2).takeRight(k).toMap.keys

    (top10Keys1 ++ top10Keys2)
      .map(key => (key, combined1.getOrElse(key, 0L) + combined2.getOrElse(key, 0L)))
      .toMap
  }
}

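(For context, the map-side half of these combiners can be checked locally without Spark. The helper `countValues` below is mine, not part of the original code, and the sample values are invented; it folds a key's values through `createCombiner`/`mergeValue` the way combineByKey would on one partition.)

```scala
// Local sanity check of the map-side combiners, plain Scala only.
def createCombiner(value: String): Map[String, Long] = Map(value -> 1L)

def mergeValue(combined: Map[String, Long], value: String): Map[String, Long] =
  combined ++ Map(value -> (combined.getOrElse(value, 0L) + 1L))

// Hypothetical helper: fold a (non-empty) list of one key's values into a count map.
def countValues(values: Seq[String]): Map[String, Long] =
  values.tail.foldLeft(createCombiner(values.head))(mergeValue)

println(countValues(Seq("x", "x", "y")))  // Map(x -> 2, y -> 1)
```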
I use this as follows:

// input is RDD[(String, String)]
val topKValueCount: RDD[(String, Map[String, Long])] = input.combineByKey(
  TopKCount.createCombiner,
  TopKCount.mergeValue,
  TopKCount.mergeCombiners
)

One optimization to the current code would be to use a min-queue during mergeCombiners.
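(Editor's sketch of that idea, under the same count-map shape as the code above. `BoundedTopK` and `mergeTopK` are my names, not from the original; the point is that a bounded min-heap keeps at most k entries alive instead of sorting whole maps.)

```scala
import scala.collection.mutable

// Hypothetical helper: merge two count maps, keeping only the k entries with
// the highest counts, via a bounded min-heap rather than a full sort.
object BoundedTopK {
  def mergeTopK(m1: Map[String, Long], m2: Map[String, Long], k: Int): Map[String, Long] = {
    // sum counts for keys present in either map
    val summed = m1 ++ m2.map { case (key, c) => key -> (m1.getOrElse(key, 0L) + c) }
    // PriorityQueue dequeues the "greatest" element; ordering by negated count
    // makes the smallest count the greatest, i.e. a min-heap on counts
    val heap = mutable.PriorityQueue.empty[(String, Long)](Ordering.by[(String, Long), Long](e => -e._2))
    summed.foreach { e =>
      heap.enqueue(e)
      if (heap.size > k) heap.dequeue()  // evict the current minimum
    }
    heap.toMap
  }
}
```

With `mergeTopK(Map("a" -> 5L, "b" -> 1L), Map("b" -> 1L, "c" -> 3L), 2)` the summed counts are a=5, b=2, c=3, and the heap keeps the two largest: a and c.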

I'm more concerned about the network I/O. Would it be possible that, once I do the merging within a partition, I send only the top-k entries from that partition to the driver, instead of sending the entire map, as I do in the current code?

Any feedback would be highly appreciated.


I've been able to solve this satisfactorily as follows. The trick is to break the problem into two parts: in the first part, combine the key and its value together to get the count of how many times each (k, v) pair occurs, and then feed those counts into a new top-k combiner to fetch the top-k occurring values.

case class TopKCount(topK: Int = 10) {

  //sort and trim a traversable of (String, Long) tuples by the _2 value of the tuple
  def topNs(xs: TraversableOnce[(String, Long)], n: Int): Seq[(String, Long)] = {
    var ss = List[(String, Long)]()
    var min = Long.MaxValue
    var len = 0
    xs foreach { e =>
      if (len < n || e._2 > min) {
        ss = (e :: ss).sortBy(f => f._2)
        min = ss.head._2
        len += 1
      }
      if (len > n) {
        ss = ss.tail
        min = ss.head._2
        len -= 1
      }
    }
    ss
  }

  //seed a list for each key to hold your top N's with your first record
  def createCombiner(value: (String, Long)): Seq[(String, Long)] = Seq(value)

  //add the incoming value to the accumulating top N list for the key
  def mergeValue(combined: Seq[(String, Long)], value: (String, Long)): Seq[(String, Long)] =
    topNs(combined ++ Seq((value._1, value._2)), topK)

  //merge top N lists returned from each partition into a new combined top N list
  def mergeCombiners(combined1: Seq[(String, Long)], combined2: Seq[(String, Long)]): Seq[(String, Long)] =
    topNs(combined1 ++ combined2, topK)
}
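(To sanity-check the trimming behaviour outside Spark, topNs can be re-stated as a standalone function and run on made-up counts; the sample data below is invented.)

```scala
// Standalone copy of the topNs trimming logic from the answer: keep the n
// tuples with the largest counts, returned in ascending order of count.
def topNs(xs: TraversableOnce[(String, Long)], n: Int): List[(String, Long)] = {
  var ss = List[(String, Long)]()
  var min = Long.MaxValue
  var len = 0
  xs foreach { e =>
    if (len < n || e._2 > min) {
      ss = (e :: ss).sortBy(_._2)
      min = ss.head._2
      len += 1
    }
    if (len > n) {
      ss = ss.tail      // drop the current minimum
      min = ss.head._2
      len -= 1
    }
  }
  ss
}

val sample = Seq(("a", 5L), ("b", 1L), ("c", 9L), ("d", 4L), ("e", 2L))
println(topNs(sample, 3))  // List((d,4), (a,5), (c,9))
```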

import scala.util.Try

object FieldValueCount {
  //Field Value Count combiners
  def createCombiner(value: String): (Double, Long) =
    if (Try(value.toDouble).isSuccess) (value.toDouble, 1L)
    else (0.0, 1L)

  def mergeValue(combined: (Double, Long), value: String): (Double, Long) =
    if (Try(value.toDouble).isSuccess) (combined._1 + value.toDouble, combined._2 + 1L)
    else (combined._1, combined._2 + 1L)

  def mergeCombiners(combined1: (Double, Long), combined2: (Double, Long)): (Double, Long) =
    (combined1._1 + combined2._1, combined1._2 + combined2._2)
}

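(The FieldValueCount logic can likewise be checked with a plain fold; note it needs `scala.util.Try` in scope. The helper name `foldFieldValues` is mine, and the sample strings are invented.)

```scala
import scala.util.Try

// Fold a list of raw strings the way the FieldValueCount combiners would:
// sum the values that parse as numbers, and count every record either way.
def foldFieldValues(values: Seq[String]): (Double, Long) =
  values.foldLeft((0.0, 0L)) { case ((sum, count), v) =>
    if (Try(v.toDouble).isSuccess) (sum + v.toDouble, count + 1L)
    else (sum, count + 1L)
  }

println(foldFieldValues(Seq("1.5", "oops", "2.5")))  // (4.0,3)
```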
// Example usage. Here input is the RDD[(String, String)]
val topKCount = TopKCount(10)


// combine the k,v from the original input to convert it into (k, (v, count))
val combined: RDD[(String, (String, Long))] = input.map(v => (v._1 + "|" + v._2, 1L))
  .reduceByKey(_ + _)
  .map(k => (k._1.split("\\|", -1).head, (k._1.split("\\|", -1).drop(1).head, k._2)))

val topKValueCount: RDD[(String, Seq[(String, Long)])] = combined.combineByKey(
  topKCount.createCombiner,
  topKCount.mergeValue,
  topKCount.mergeCombiners
)

TopKCount has been converted to a case class so that the value of k can be changed as needed. It can be made an object if k does not need to be variable.
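(As a final check, the two-step pipeline — count each (k, v) pair, then take the top k values per key — can be simulated on plain collections. `topKPerKey` is my name for this local stand-in, and the sample input is invented.)

```scala
// Local simulation of the two-step pipeline, no Spark required.
def topKPerKey(input: Seq[(String, String)], topK: Int): Map[String, Seq[(String, Long)]] = {
  // Step 1: count occurrences of each (key, value) pair (stands in for reduceByKey).
  val counted: Map[(String, String), Long] =
    input.groupBy(identity).map { case (kv, occs) => kv -> occs.size.toLong }
  // Step 2: per key, keep the topK values by descending count (stands in for combineByKey).
  counted.toSeq
    .map { case ((key, value), count) => (key, (value, count)) }
    .groupBy(_._1)
    .map { case (key, entries) =>
      key -> entries.map(_._2).sortBy(-_._2).take(topK)
    }
}

val sampleInput = Seq(
  ("k1", "a"), ("k1", "a"), ("k1", "a"),
  ("k1", "b"), ("k1", "b"),
  ("k1", "c"),
  ("k2", "x")
)
println(topKPerKey(sampleInput, 2))
```

For `k1` this yields `Seq(("a", 3), ("b", 2))`; the Spark version produces the same per-key result, with the counting distributed across partitions.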