Run2 Run2 - 4 months ago 36
Scala Question

How to get StratifiedKFold in Scala / Spark MLlib?

I searched a bit and, not finding one, wrote a StratifiedKFold method that can be used with Scala and Spark (MLlib). I am posting it as an answer below.

Answer
/** Splits sample indices into `k` stratified folds, computed separately for each class label.
  *
  * Indices `0 until nSamples` are grouped by `labels(index)`, keeping ascending index order
  * within each group. Each group is optionally shuffled and then cut into exactly `k`
  * contiguous folds. Fold sizes differ by at most one: the first `groupSize % k` folds hold
  * one extra index. Taking fold `j` from every label therefore yields a fold whose class
  * proportions mirror those of the whole sample.
  *
  * @param nSamples number of samples to split; only the first `nSamples` labels are used
  * @param k        number of folds; must be positive
  * @param labels   class label of each sample, indexed by sample position
  * @param shuffle  when true, each label group is shuffled with `bshuffle` before splitting
  * @return a map from each distinct label (among the first `nSamples` labels) to its `k` folds
  *         of sample indices, paired with the number of distinct labels. A fold is empty when
  *         the label has fewer than `k` samples. Keys are the label values themselves, so for
  *         labels `0 until L` the keys are `0 until L`.
  * @throws IllegalArgumentException if `k <= 0`, `nSamples < 0` or `nSamples > labels.length`
  */
def StratifiedKFold(nSamples: Int, k: Int, labels: List[Int], shuffle: Boolean = false): (Map[Int, List[List[Int]]], Int) = {
  require(k > 0, s"k must be positive, got $k")
  require(nSamples >= 0 && nSamples <= labels.length,
    s"nSamples must be in [0, ${labels.length}], got $nSamples")

  // Cuts `xs` into exactly k contiguous folds whose sizes differ by at most one.
  def splitIntoFolds(xs: List[Int]): List[List[Int]] = {
    val baseSize = xs.length / k
    val extra = xs.length % k
    val sizes = List.tabulate(k)(j => if (j < extra) baseSize + 1 else baseSize)
    sizes
      .foldLeft((List.empty[List[Int]], xs)) { case ((folds, rest), size) =>
        val (fold, remaining) = rest.splitAt(size)
        (fold :: folds, remaining)
      }
      ._1
      .reverse
  }

  // zipWithIndex avoids an O(n) List.apply per sample; groupBy preserves traversal order,
  // so indices stay ascending inside each label group.
  val idxsByLabel: Map[Int, List[Int]] =
    labels
      .take(nSamples)
      .zipWithIndex
      .groupBy(_._1)
      .map { case (label, pairs) => label -> pairs.map(_._2) }

  val stratifiedIdxs: Map[Int, List[List[Int]]] = idxsByLabel.map { case (label, idxs) =>
    val ordered = if (shuffle) bshuffle(idxs.toArray).toList else idxs
    label -> splitIntoFolds(ordered)
  }

  (stratifiedIdxs, idxsByLabel.size)
}