Neysofu - 1 year ago 142

Scala Question

I'm writing a server-side module for a Scala-based project, and I need the fastest way to perform weighted random selection over an array of `Int` weights (returning the index of the chosen weight).

Now, this is what I came up with:

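```scala
import scala.util.Random

trait CumulativeDensity {

  /** Returns the index of the first entry in the discrete @cdf array that is
   *  >= @n (for @n >= 1), found by linear search, or -1 if there is none.
   *  Entries that are skipped are set to 0, so @cdf is modified in place.
   */
  def search(n: Int, cdf: Array[Int]): Int = {
    val i: Int = cdf.indexWhere(_ != 0)
    // `||` short-circuits, so cdf(i) is never read when i is -1.
    if (i < 0 || n <= cdf(i))
      i
    else
      // n is compared with cumulative values, so it is passed on unchanged;
      // zeroing entry i makes the next indexWhere skip past it.
      search(n, { cdf.update(i, 0); cdf })
  }

  /** Returns the cumulative distribution function (CDF) of @list (in simple
   *  terms, the cumulative sums of the weights).
   */
  def cdf(list: Array[Int]): Array[Int] = list.map {
    var s = 0 // running total, captured by the function literal below
    d => { s += d; s }
  }
}
```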

And I define the main method with this piece of code:

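```scala
// nextInt(bound) is uniform on 0 until bound, so the ticket n is uniform on 1..list.sum.
def rndWeighted(list: Array[Int]): Int =
  search(Random.nextInt(list.sum) + 1, cdf(list))
```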
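A minimal, self-contained sketch of the whole draw, assuming the goal is to return index `i` with probability `list(i) / list.sum` for non-negative weights with a positive sum. The object `WeightedDraw` and its helpers `cdf`, `pick` and `rndWeighted` are hypothetical names, with `indexWhere` standing in for the `search` method. `Random.nextInt(bound)` is uniform on `0 until bound`, so `nextInt(total) + 1` yields exactly `total` equally likely tickets, one per unit of weight; `nextInt(total + 1)` would add an extra ticket 0, which the `search` above maps to the first non-zero weight, skewing the distribution.

```scala
import scala.util.Random

object WeightedDraw {
  // Prefix sums of the weights: Array(1, 0, 3) -> Array(1, 1, 4).
  def cdf(weights: Array[Int]): Array[Int] = weights.scanLeft(0)(_ + _).tail

  // First index i with n <= cdf(i); stands in for the linear `search` above.
  def pick(n: Int, cdf: Array[Int]): Int = cdf.indexWhere(n <= _)

  // Returns index i with probability weights(i) / weights.sum.
  // Assumes a non-empty array of non-negative weights with a positive sum.
  def rndWeighted(weights: Array[Int], rng: Random = Random): Int = {
    val c = cdf(weights)
    // nextInt(total) is uniform on 0 until total; +1 shifts it to 1..total.
    pick(rng.nextInt(c.last) + 1, c)
  }

  def main(args: Array[String]): Unit = {
    val weights = Array(1, 0, 3)
    val c = cdf(weights)
    val counts = new Array[Int](weights.length)
    // Each of the c.last equally likely tickets maps to exactly one index.
    for (n <- 1 to c.last) counts(pick(n, c)) += 1
    println(counts.mkString(", ")) // prints "1, 0, 3"
    println(rndWeighted(weights)) // 0 or 2, never 1
  }
}
```

Enumerating every ticket, as `main` does, checks the ticket-to-index mapping exactly instead of relying on sampling statistics.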

However, it still isn't fast enough. Is there any kind of black magic (libraries, built-ins, heuristics) that makes it unnecessary to iterate over the list from its start?

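A recursive variant of `search`:

```scala
def search(n: Int, cdf: Array[Int]): Int = {
  // n is compared with cumulative values, so it is passed on unchanged.
  if (n > cdf.head)
    1 + search(n, cdf.tail)
  else
    0
}
```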
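On an `Array`, `cdf.tail` allocates a new array holding all but the first element, so this variant performs O(n²) element copies in the worst case, and `1 + search(...)` is not a tail call, so stack depth grows with the returned index. Below is a sketch of the same linear scan that advances an index instead; `LinearSearch` and `searchFrom` are hypothetical names, and it assumes the same meaning as above: the first `i` with `n <= cdf(i)`, or -1 when `n` exceeds the total weight. `@tailrec` asks the compiler to verify that the recursion compiles to a loop.

```scala
import scala.annotation.tailrec

object LinearSearch {
  // Returns the first index i >= from with n <= cdf(i), or -1 if none.
  // Same linear scan as the recursive `search`, but it advances an index
  // instead of copying the array with `tail`.
  @tailrec
  def searchFrom(n: Int, cdf: Array[Int], from: Int = 0): Int =
    if (from >= cdf.length) -1
    else if (n <= cdf(from)) from
    else searchFrom(n, cdf, from + 1)

  def main(args: Array[String]): Unit = {
    val cdf = Array(1, 3, 6) // prefix sums of the weights 1, 2, 3
    println((1 to 6).map(n => searchFrom(n, cdf)).mkString(" ")) // prints "0 1 1 2 2 2"
  }
}
```

This removes the copying and the stack growth, but each draw is still linear in the number of weights.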


Answer

Instead of calling `cdf.update(i,0)` and passing the entire `cdf` back to `cdf.indexWhere(_ != 0)` in the next recursive call, consider `cdf.splitAt(i + 1)` and passing only the second half, i.e. the elements to the *right* of `i`, so that in the following recursion `indexWhere` scans a smaller array. The index returned by that call is relative to the smaller array, so add `i + 1` to it (unless it is -1). Because the array shrinks by at least one element at each recursive call, the recursion is guaranteed to terminate.
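A sketch of the answer's suggestion next to a binary search, assuming non-negative weights so that the CDF is non-decreasing. `CdfSearch`, `searchRight` and `lowerBound` are hypothetical names; both functions return the first index `i` with `n <= cdf(i)`, or -1. `searchRight` follows the answer: it recurses on the part after `i` and adds `i + 1` back to the result. Keep in mind that `splitAt` on an `Array` copies the right-hand part, so each step still touches O(n) elements. `lowerBound` is a swapped-in technique not mentioned in the answer: a hand-written binary search needing O(log n) comparisons per lookup. It is written by hand rather than with `java.util.Arrays.binarySearch` because zero weights produce repeated CDF values, and that method does not guarantee which of several equal elements it finds.

```scala
object CdfSearch {
  // The answer's idea: instead of zeroing cdf(i), recurse on the elements to
  // the right of i, then shift the result back by i + 1. Not tail-recursive.
  // Since nothing is zeroed, indexWhere(_ != 0) only skips leading
  // zero-weight entries here.
  def searchRight(n: Int, cdf: Array[Int]): Int = {
    val i = cdf.indexWhere(_ != 0)
    if (i < 0 || n <= cdf(i)) i
    else {
      val (_, right) = cdf.splitAt(i + 1) // copies the right-hand part
      val j = searchRight(n, right)       // index relative to `right`
      if (j < 0) -1 else i + 1 + j
    }
  }

  // Binary search (lower bound) over a non-decreasing CDF: returns the first
  // index i with n <= cdf(i), or -1 if n is larger than every entry.
  def lowerBound(n: Int, cdf: Array[Int]): Int = {
    var lo = 0
    var hi = cdf.length // invariant: cdf(k) < n for k < lo, n <= cdf(k) for k >= hi
    while (lo < hi) {
      val mid = (lo + hi) >>> 1
      if (cdf(mid) < n) lo = mid + 1 else hi = mid
    }
    if (lo < cdf.length) lo else -1
  }

  def main(args: Array[String]): Unit = {
    val cdf = Array(2, 2, 3, 6) // prefix sums of the weights 2, 0, 1, 3
    println((1 to 6).map(n => searchRight(n, cdf)).mkString(" ")) // prints "0 0 2 3 3 3"
    println((1 to 6).map(n => lowerBound(n, cdf)).mkString(" "))  // prints "0 0 2 3 3 3"
  }
}
```

The O(log n) lookup pays off mainly when many draws share the same weights: building the CDF is itself O(n), so it would be built once and reused rather than recomputed on every call as `rndWeighted` does.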