Amir Afghani - 3 months ago
Scala Question

Interleaving iterators

I wrote the following code, expecting the last print call to show the elements of both iterators combined. Instead it only shows the elements of perfectSquares. Can someone explain this to me?

object Fuge {

  def main(args: Array[String]) : Unit = {

    perfectSquares.takeWhile(_ < 100).foreach(square => print(square + " "))
    println()
    triangles.takeWhile(_ < 100).foreach(triangle => print(triangle + " "))
    println()
    (perfectSquares ++ triangles).takeWhile(_ < 100).foreach(combine => print(combine + " "))

  }

  def perfectSquares : Iterator[Int] = {
    Iterator.from(1).map(x => x * x)
  }

  def triangles : Iterator[Int] = {
    Iterator.from(1).map(n => (n * (n + 1)/2))
  }

}


OUTPUT:

1 4 9 16 25 36 49 64 81
1 3 6 10 15 21 28 36 45 55 66 78 91
1 4 9 16 25 36 49 64 81

Answer

From the documentation on takeWhile:

  /** Takes longest prefix of values produced by this iterator that satisfy a predicate.
   *
   *  @param   p  The predicate used to test elements.
   *  @return  An iterator returning the values produced by this iterator, until
   *           this iterator produces a value that does not satisfy
   *           the predicate `p`.
   *  @note    Reuse: $consumesAndProducesIterator
   */
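Both effects can be checked in isolation. The following is a minimal sketch (TakeWhileDemo and its value names are hypothetical, made up only for this illustration) showing that takeWhile stops at the first failing element, and that ++ only moves on to its right-hand iterator once the left-hand one is exhausted:

```scala
object TakeWhileDemo {
  // takeWhile stops at the first element that fails the predicate,
  // even if later elements would satisfy it again.
  val stopsEarly: List[Int] = Iterator(1, 5, 200, 2, 3).takeWhile(_ < 100).toList

  // With a finite left-hand side, ++ does move on to the right-hand iterator.
  val finiteConcat: List[Int] = (Iterator(1, 4, 9) ++ Iterator(1, 3, 6)).toList

  def squares: Iterator[Int] = Iterator.from(1).map(x => x * x)
  def triangles: Iterator[Int] = Iterator.from(1).map(n => n * (n + 1) / 2)

  // With an infinite left-hand side, the right-hand iterator is never reached.
  val infiniteConcat: List[Int] = (squares ++ triangles).takeWhile(_ < 100).toList

  // Even a finite left-hand side that crosses the bound makes takeWhile stop
  // (at 100) before ++ ever switches to the right-hand iterator.
  val finiteButTooLong: List[Int] = (squares.take(12) ++ triangles).takeWhile(_ < 100).toList

  def main(args: Array[String]): Unit = {
    println(stopsEarly)       // List(1, 5)
    println(finiteConcat)     // List(1, 4, 9, 1, 3, 6)
    println(infiniteConcat)   // List(1, 4, 9, 16, 25, 36, 49, 64, 81)
    println(finiteButTooLong) // List(1, 4, 9, 16, 25, 36, 49, 64, 81)
  }
}
```

The last two lists show that the triangular numbers never appear, whether the squares are infinite or merely long enough to reach 100.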

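The iterator returned by takeWhile ends at the first element that fails the predicate and never looks at anything after it. perfectSquares ++ triangles yields every perfect square first and would only switch to the triangular numbers once perfectSquares ran out; since perfectSquares is infinite, that never happens. Even if it were finite, takeWhile would already have stopped at 100 (10 * 10), the first square that is not below 100. To see both sequences, you need to interleave the two iterators instead, for example with a lazily built Stream: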

object Fuge {
  def main(args: Array[String]) : Unit = {

    perfectSquares.takeWhile(_ < 100).foreach(square => print(square + " "))
    println()
    triangles.takeWhile(_ < 100).foreach(triangle => print(triangle + " "))
    println()
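    // Alternate between a and b: each call takes one element from each
    // iterator, and the recursive call is only evaluated once the stream
    // is traversed that far.
    def interleave (a: Iterator[Int], b: Iterator[Int]): Stream[Int] = {
      if (a.isEmpty || b.isEmpty) { Stream.empty }
      else {
        a.next() #:: b.next() #:: interleave(a, b)
      }
    }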
    lazy val interleaved = interleave(perfectSquares, triangles)
    interleaved.takeWhile(_ < 100).foreach(combine => print(combine + " "))
  }

  def perfectSquares : Iterator[Int] = {
    Iterator.from(1).map(x => x * x)
  }

  def triangles : Iterator[Int] = {
    Iterator.from(1).map(n => (n * (n + 1)/2))
  }
}

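Here I'm using a Stream so that the combined sequence is built lazily: interleave takes one element from each iterator per step, and the rest is only computed as takeWhile asks for it. Note that the result is only interleaved, not sorted: takeWhile still stops at the first value that is not below 100 (the square 100, which comes right after the triangle 45), so the triangular numbers 55, 66, 78 and 91 never appear.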

This yields:

1 4 9 16 25 36 49 64 81 
1 3 6 10 15 21 28 36 45 55 66 78 91 
1 1 4 3 9 6 16 10 25 15 36 21 49 28 64 36 81 45
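If you would rather not build a Stream at all, the same round-robin interleaving can be sketched with plain iterators. This is an alternative technique, not the code above; ZipInterleave is a hypothetical name, and, like the Stream version, it only emits pairs while both sides still have elements:

```scala
object ZipInterleave {
  // zip pairs up one element from each side (stopping when either runs out),
  // and flatMap emits each pair as two consecutive elements. Both operations
  // are lazy on Iterator, so this is safe on infinite inputs.
  def interleave(a: Iterator[Int], b: Iterator[Int]): Iterator[Int] =
    a.zip(b).flatMap { case (x, y) => Iterator(x, y) }

  def perfectSquares: Iterator[Int] = Iterator.from(1).map(x => x * x)
  def triangles: Iterator[Int] = Iterator.from(1).map(n => n * (n + 1) / 2)

  def main(args: Array[String]): Unit =
    println(interleave(perfectSquares, triangles).takeWhile(_ < 100).mkString(" "))
}
```

The output is still unsorted, so takeWhile stops at 100 exactly as in the Stream version and prints the same line.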

To get the values in sorted order, you need a BufferedIterator and a slightly different interleave function. Calling next() advances an iterator, and you can't go back; you also can't know in advance how many items you need from a before you need one from b, and vice versa. A BufferedIterator adds head, which peeks at the next element without advancing the iterator, so each step can compare the two heads and consume only the smaller one (a merge, as in merge sort). This yields sorted output because both inputs are already increasing. Now the code becomes:

object Fuge {
  def main(args: Array[String]) : Unit = {
    perfectSquares.takeWhile(_ < 100).foreach(square => print(square + " "))
    println()
    triangles.takeWhile(_ < 100).foreach(triangle => print(triangle + " "))
    println()
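    // Merge two increasing sequences: peek at both heads without consuming
    // them, emit the smaller one, and advance only that iterator.
    def interleave (a: BufferedIterator[Int], b: BufferedIterator[Int]): Stream[Int] = {
      if (a.isEmpty || b.isEmpty) { Stream.empty }
      else if (a.head <= b.head) {
        a.next() #:: interleave(a, b)
      } else {
        b.next() #:: interleave(a, b)
      }
    }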
    lazy val interleaved = interleave(perfectSquares.buffered, triangles.buffered)
    interleaved.takeWhile(_ < 100).foreach(combine => print(combine + " "))
  }

  def perfectSquares : Iterator[Int] = {
    Iterator.from(1).map(x => x * x)
  }

  def triangles : Iterator[Int] = {
    Iterator.from(1).map(n => (n * (n + 1)/2))
  }
}

And the output is:

1 4 9 16 25 36 49 64 81 
1 3 6 10 15 21 28 36 45 55 66 78 91 
1 1 3 4 6 9 10 15 16 21 25 28 36 36 45 49 55 64 66 78 81 91
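The same merge can also be written as an Iterator instead of a Stream. This is a sketch under the assumption that both inputs are sorted in increasing order; BufferedMerge is a hypothetical name, and unlike the interleave above it keeps draining whichever side is left once the other runs out:

```scala
object BufferedMerge {
  // Merge two iterators that are each sorted in increasing order.
  // Assumption: both inputs are sorted; if not, the output isn't sorted either.
  def merge(a: Iterator[Int], b: Iterator[Int]): Iterator[Int] = new Iterator[Int] {
    private val left = a.buffered
    private val right = b.buffered

    // More elements remain as long as either side has some.
    def hasNext: Boolean = left.hasNext || right.hasNext

    def next(): Int =
      if (!right.hasNext) left.next()      // right exhausted: drain left (throws if both are empty)
      else if (!left.hasNext) right.next() // left exhausted: drain right
      // head peeks without consuming, so only the smaller side advances
      else if (left.head <= right.head) left.next()
      else right.next()
  }

  def perfectSquares: Iterator[Int] = Iterator.from(1).map(x => x * x)
  def triangles: Iterator[Int] = Iterator.from(1).map(n => n * (n + 1) / 2)

  def main(args: Array[String]): Unit =
    println(merge(perfectSquares, triangles).takeWhile(_ < 100).mkString(" "))
}
```

Its main method should print the same merged line as the output above. Staying with Iterator also sidesteps Stream, which has been deprecated since Scala 2.13 in favour of LazyList.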