djanderson djanderson - 3 months ago 25
Scala Question

Processing an Akka Stream with a One-Time Header

I have an application which receives a TCP socket connection which will send data in the form:

n{json}bbbbbbbbbb...


where `n` is the length in bytes of the following `json`, and the `json` might be something like `{'splitEvery': 5}`. The `json` dictates how I break up and process the potentially infinite string of bytes that follows.

I want to process this stream with Akka in Scala. I think streams are the right tool for this, but I am having a hard time finding an example that uses streams with distinct processing stages. Most stream flows seem to do the same thing over and over, like the `prefixAndTail` example here. That is very close to how I want to process the `n{json}` part of my stream. The difference is that I only need to do this once per connection and then move on to a different "stage" of processing.

Can anyone point me to an example of using Akka streams with distinct stages?

Answer

Here's a `GraphStage` that processes a stream of `ByteString`s in two phases:

  • Extract the chunk size from the one-time header
  • Emit `ByteString`s of the specified chunk size
import akka.stream.{Attributes, FlowShape, Inlet, Outlet}
import akka.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
import akka.util.ByteString

/**
  * Flow stage that consumes a one-time header of the form `n{"splitEvery": k}`
  * at the start of the stream. It then re-chunks all following bytes into
  * `ByteString`s of exactly `k` bytes. When upstream completes, a final chunk
  * shorter than `k` is emitted if bytes remain.
  *
  * The header may be split across any number of upstream elements; bytes are
  * buffered until it can be parsed. The numeric prefix `n` is matched but not
  * checked against the actual JSON length.
  *
  * The stage fails with `IllegalArgumentException` if the header announces a
  * chunk size of zero.
  */
class PreProcessor extends GraphStage[FlowShape[ByteString, ByteString]] {

  val in: Inlet[ByteString] = Inlet("ParseHeader.in")
  val out: Outlet[ByteString] = Outlet("ParseHeader.out")

  override val shape: FlowShape[ByteString, ByteString] = FlowShape.of(in, out)

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
    new GraphStageLogic(shape) {

      // Bytes received from upstream that have not been emitted yet.
      private var buffer = ByteString.empty
      // Chunk size announced by the header; None until the header is parsed.
      private var chunkSize: Option[Int] = None
      private var upstreamFinished = false

      // Accepts double- or single-quoted keys and optional whitespace, e.g.
      // `17{"splitEvery": 5}` or `16{'splitEvery':5}`. Everything this pattern
      // can match is ASCII (`\d` and `\s` are ASCII-only in java.util.regex by
      // default). So the match end, a character offset, equals the header's
      // length in bytes.
      private val headerPattern =
        """^\d+\{\s*["']splitEvery["']\s*:\s*(\d+)\s*\}""".r

      /**
        * @param data The data to parse.
        * @return The chunk size and header size (in bytes) if the header
        *         could be parsed.
        */
      private def parseHeader(data: ByteString): Option[(Int, Int)] =
        headerPattern
          .findFirstMatchIn(data.decodeString("UTF-8"))
          .map(mtch => (mtch.group(1).toInt, mtch.end))

      setHandler(out, new OutHandler {
        // Serve already-buffered chunks before asking upstream for more.
        // Otherwise a single large upstream element yields only one chunk per
        // pull, and the buffer grows without bound.
        override def onPull(): Unit = emit()
      })

      setHandler(in, new InHandler {
        override def onPush(): Unit = {
          buffer ++= grab(in)
          if (chunkSize.isEmpty) {
            parseHeader(buffer).foreach { case (chunk, headerSize) =>
              chunkSize = Some(chunk)
              buffer = buffer.drop(headerSize)
            }
          }
          chunkSize match {
            // A zero chunk size would push empty ByteStrings forever without
            // ever consuming the buffer.
            case Some(size) if size <= 0 =>
              failStage(new IllegalArgumentException(s"splitEvery must be positive, got $size"))
            case _ =>
              emit()
          }
        }

        override def onUpstreamFinish(): Unit = {
          upstreamFinished = true
          // Nothing left to emit: either no header ever arrived, or the
          // buffer is already drained.
          if (chunkSize.isEmpty || buffer.isEmpty) completeStage()
          else if (isAvailable(out)) emit()
        }
      })

      /** Requests more input, or completes the stage if upstream is closed. */
      private def continue(): Unit =
        if (isClosed(in)) completeStage()
        else pull(in)

      /**
        * Pushes the next chunk downstream if one is ready; otherwise requests
        * more input, or completes when nothing is left. Because this may push,
        * callers must ensure `out` is available. Once upstream has finished,
        * a trailing chunk shorter than the chunk size is flushed.
        */
      private def emit(): Unit =
        chunkSize match {
          case None => continue()
          case Some(size) =>
            val needMore =
              if (upstreamFinished) buffer.isEmpty
              else buffer.size < size
            if (needMore) continue()
            else {
              val (chunk, rest) = buffer.splitAt(size)
              buffer = rest
              push(out, chunk)
            }
        }
    }
}

And the test case to illustrate the usage:

import akka.actor.ActorSystem
import akka.stream._
import akka.stream.scaladsl.Source
import akka.util.ByteString
import org.scalatest._

import scala.concurrent.Await
import scala.concurrent.duration._
import scala.util.Random

class PreProcessorSpec extends FlatSpec with BeforeAndAfterAll {

  implicit val system: ActorSystem = ActorSystem("Test")
  implicit val materializer: ActorMaterializer = ActorMaterializer()

  val random = new Random

  /** Terminates the actor system so its threads do not outlive the suite. */
  override protected def afterAll(): Unit =
    try Await.result(system.terminate(), 10.seconds)
    finally super.afterAll()

  private val input = """17{"splitEvery": 5}aaaaabbbbbcccccddd"""
  private val expected = Vector("aaaaa", "bbbbb", "ccccc", "ddd")

  /**
    * Splits `s` into consecutive pieces of random length in `[0, n)`. This
    * simulates arbitrary TCP framing; empty pieces are produced on purpose.
    */
  private def splitRandom(s: String, n: Int): List[String] =
    if (s.isEmpty) Nil
    else {
      val (head, tail) = s.splitAt(random.nextInt(n))
      head :: splitRandom(tail, n)
    }

  /** Runs `pieces` through a PreProcessor and collects the emitted chunks. */
  private def preprocess(pieces: List[String]): Vector[String] = {
    val future = Source.fromIterator(() => pieces.iterator)
      .map(ByteString(_))
      .via(new PreProcessor())
      .map(_.decodeString("UTF-8"))
      .runFold(Vector.empty[String])(_ :+ _)
    Await.result(future, 5.seconds)
  }

  "PreProcessor" should "strip the header and emit fixed-size chunks regardless of framing" in {
    val strings = splitRandom(input, 7)
    println(strings.map(s => s"[$s]").mkString(" ") + "\n")

    val chunks = preprocess(strings)
    chunks.foreach(println)

    assertResult(expected)(chunks)
  }

  it should "emit every chunk when one upstream element holds the whole input" in {
    assertResult(expected)(preprocess(List(input)))
  }

}

Example output:

[17{"] [splitE] [very"] [] [: 5}] [aaaaa] [bbb] [bbcccc] [] [cddd]

aaaaa
bbbbb
ccccc
ddd