Neysofu - 1 month ago
Scala Question

Fastest weighted random algorithm in Scala?

I'm writing a server-side module for a Scala-based project, and I need to find the fastest way to perform weighted random number generation over some Int weights. The method should be as fast as possible since it will be called very often.

Now, this is what I came up with:

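import scala.util.Random

trait CumulativeDensity {

  /** Returns the index of the first entry of the discrete @cdf array that is
   *  greater than or equal to @n, found by a linear scan, or -1 if there is none.
   */
  def search(n: Int, cdf: Array[Int]): Int = {
    // Entries that were already visited are zeroed out, so skip them.
    val i: Int = cdf.indexWhere(_ != 0)
    // `||` short-circuits, so cdf(i) is never evaluated when i is -1.
    if (i < 0 || n <= cdf(i))
      i
    else
      // The sums are cumulative, so n is compared unchanged against the next entry.
      search(n, {cdf.update(i, 0); cdf})
  }

  /** Returns the cumulative density function (CDF) of @list (in simple terms,
   *  the cumulative sums of the weights).
   */
  def cdf(list: Array[Int]) = list.map {
    var s = 0;
    d => {s += d; s}
  }
}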
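
To make the "cumulative sums" comment concrete, here is a minimal sketch of the same computation written with scanLeft instead of a mutable variable. The object name CumulativeSumsExample, the method name cumulativeSums, and the sample weights are assumptions chosen for illustration, not part of the original code.

```scala
object CumulativeSumsExample {
  /** Running totals of the weights. scanLeft also emits the initial 0,
    * which tail drops again.
    */
  def cumulativeSums(weights: Array[Int]): Array[Int] =
    weights.scanLeft(0)(_ + _).tail

  def main(args: Array[String]): Unit = {
    // Hypothetical sample weights.
    val sums = cumulativeSums(Array(2, 3, 5))
    println(sums.mkString(", ")) // prints "2, 5, 10"
  }
}
```

With these sums, a draw between 1 and 10 falls in the first slot for 1 to 2, the second for 3 to 5, and the third for 6 to 10.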


And I define the main method with this piece of code:

def rndWeighted(list: Array[Int]): Int =
  // Draw n uniformly from 1 to list.sum; assumes at least one positive weight.
  search(Random.nextInt(list.sum) + 1, cdf(list))


However, it still isn't fast enough. Is there any kind of black magic (libraries, built-ins, heuristics) that makes it unnecessary to iterate over the list from the start?

EDIT: this is the final version of the code (much faster now):

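def search(n: Int, cdf: Array[Int]): Int = {
  // Subtracting cdf.head is correct only if the array holds the individual
  // weights rather than their cumulative sums.
  if (n > cdf.head)
    1 + search(n - cdf.head, cdf.tail)
  else
    0
}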

elm
Answer

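Instead of calling cdf.update(i, 0) and passing the entire cdf back to cdf.indexWhere(_ != 0) in the next recursive call, consider

cdf.splitAt(i + 1)

and passing only the second half, i.e. the elements to the right of i, so that in the following recursion indexWhere scans a smaller array. The index found in the smaller array is relative to it, so the number of dropped elements has to be added back to get the index in the original array. Note that the array size strictly decreases at each recursive call, which ensures termination.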
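
The suggestion above can be sketched roughly as follows. This is only an illustration under stated assumptions: the array holds cumulative sums, n is drawn from 1 to the total weight, the split is taken at i + 1 so that the element at i itself is dropped, and the offset parameter is a hypothetical addition that maps the result back to an index in the original array.

```scala
import scala.annotation.tailrec

object SplitSearchSketch {
  /** Returns the index of the first non-zero cumulative sum that is >= n,
    * or -1 if there is none. `offset` (a hypothetical helper parameter)
    * counts the elements already dropped, so the result refers to the
    * original array rather than to the shrunken one.
    */
  @tailrec
  def search(n: Int, cdf: Array[Int], offset: Int = 0): Int = {
    val i = cdf.indexWhere(_ != 0)
    if (i < 0) -1
    else if (n <= cdf(i)) offset + i
    else {
      // Keep only the elements to the right of i; the array strictly shrinks.
      val (_, right) = cdf.splitAt(i + 1)
      search(n, right, offset + i + 1)
    }
  }

  def main(args: Array[String]): Unit = {
    val cdf = Array(2, 5, 10) // cumulative sums of the hypothetical weights 2, 3, 5
    println(search(1, cdf)) // prints 0
    println(search(5, cdf)) // prints 1
    println(search(6, cdf)) // prints 2
  }
}
```

Note that splitAt copies both halves, so each step still costs time proportional to the array length; cdf.drop(i + 1) yields the same right-hand part without building the left one.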