JaKu JaKu - 3 months ago 18
Scala Question

Parallel recursion in scala

I am trying to parallelize recursive calls of sudoku solver from 25 lines Sudoku solver in Scala. I've changed their

Fold
into
reduce


def reduce(f: (Int, Int) => Int, accu: Int, l: Int, u: Int): Int = {
accu + (l until u).toArray.reduce(f(accu, _) + f(accu, _))
}


which if run sequentially works fine, but when I change it into

accu + (l until u).toArray.par.reduce(f(accu, _) + f(accu, _))


the recursion reaches the bottom much more often and generates false solutions. I thought, that it will execute the bottom level recursion and work it's way up, but doesn't seem to do so.


I've also tried futures

def parForFut2(f: (Int, Int) => Int, accu: Int, l: Int, u: Int): Int = {
var sum: Int = accu
val vals = l until u
vals.foreach(t => scala.actors.Futures.future(sum + f(accu, t)))
sum
}


which appears to have the same problem as the
par.reduce
. I would appreciate any comment. The whole code is here:

object SudokuSolver extends App {
// The board is represented by an array of string
val source = scala.io.Source.fromFile("./puzzle")
val lines = (source.getLines).toArray
var m: Array[Array[Char]] = for (
str <- lines;
line: Array[Char] = str.toArray
) yield line
source.close()

// For printing m
def print = {
Console.println("");
refArrayOps(m) map (carr => Console.println(new String(carr)))
}

// The test for validity of n on position x,y
def invalid(i: Int, x: Int, y: Int, n: Char): Boolean =
i < 9 && (m(y)(i) == n || m(i)(x) == n ||
m(y / 3 * 3 + i / 3)(x / 3 * 3 + i % 3) == n || invalid(i + 1, x, y, n))

// Looping over a half-closed range of consecutive Integers [l..u)
// is factored out Into a higher-order function
def parReduce(f: (Int, Int) => Int, accu: Int, l: Int, u: Int): Int = {
accu + (l until u).toArray.par.reduce(f(accu, _) + f(accu, _))
}

// The search function examines each position on the board in turn,
// trying the numbers 1..9 in each unfilled position
// The function is itself a higher-order fold, accumulating the value
// accu by applying the given function f to it whenever a solution m
// is found
def search(x: Int, y: Int, f: (Int) => Int, accu: Int): Int = Pair(x, y) match {
case Pair(9, y) => search(0, y + 1, f, accu) // next row
case Pair(0, 9) => f(accu) // found a solution - print it and continue
case Pair(x, y) => if (m(y)(x) != '0') search(x + 1, y, f, accu) else
parForFut1((accu: Int, n: Int) =>
if (invalid(0, x, y, (n + 48).asInstanceOf[Char])) accu else {
m(y)(x) = (n + 48).asInstanceOf[Char];
val newaccu = search(x + 1, y, f, accu);
m(y)(x) = '0';
newaccu
}, accu, 1, 10)
}

// The main part of the program uses the search function to accumulate
// the total number of solutions
Console.println("\n" + search(0, 0, i => { print; i + 1 }, 0) + " solution(s)")
}

Answer

After Andreas comment I changed the m: Array[Array[Char]] into m: List[List[Char]] which prevents any unnecessary and unwanted changes to it. The final looping method is

    def reduc(f: (Int, Int) => Int, 
                  accu: Int, l: Int, u: Int, m1: List[List[Char]]):Int =
    accu + (l until u).toArray.par.reduce(f(accu, _) + f(accu, _))

and I had to pass m as an argument to each used function, so every one of them had its own instance of it. The whole code:

    object SudokuSolver extends App{
      // The board is represented by an Array of strings (Arrays of Chars),
      val source = scala.io.Source.fromFile("./puzzle")

      val lines = source.getLines.toList             
      val m: List[List[Char]] = for (
        str <- lines;
        line: List[Char] = str.toList
      ) yield line 
      source.close()

      // For prInting m
      def printSud(m: List[List[Char]]) = {
        Console.println("")
        m map (println)
      }                                               

      Console.println("\nINPUT:")                     
      printSud(m)  

      def invalid(i:Int, x:Int, y:Int, n:Char,m1: List[List[Char]]): Boolean =
        i < 9 && (m1(y)(i) == n || m1(i)(x) == n ||
          m1(y / 3 * 3 + i / 3)(x / 3 * 3 + i % 3) == n ||
          invalid(i + 1, x, y, n, m1))

      def reduc(f: (Int, Int) => Int, accu: Int, l: Int, u: Int, 
                m1: List[List[Char]]): Int =
        accu + (l until u).toArray.par.reduce(f(accu, _) + f(accu, _))

      def search(x: Int, y: Int, accu: Int, m1: List[List[Char]]): Int = 
        Pair(x, y) match {
          case Pair(9, y) => search(0, y + 1, accu, m1) // next row
          case Pair(0, 9) => { printSud(m1); accu + 1 } // found a solution
          case Pair(x, y) =>
            if (m1(y)(x) != '0')
              search(x + 1, y, accu, m1) // place is filled, we skip it.
            else // box is not filled, we try all n in {1,...,9}
              reduc((accu: Int, n: Int) => {
                if (invalid(0, x, y, (n + 48).asInstanceOf[Char], m1))
                  accu
                else { // n fits here
                  val line = List(m1(y).patch(x, Seq((n + 48).asInstanceOf[Char]), 1))
                  val m2 = m1.patch(y, line, 1)
                  val newaccu = search(x + 1, y, accu, m2);
                  val m3 = m1.patch(y, m1(y).patch(x, Seq(0), 1), 1)
                  newaccu
                }
            }, accu, 1, 10, m1)
      }                                               

      Console.println("\n" + search(0, 0, 0, m) + " solution(s)")

    }
Comments