Georg Heiler Georg Heiler - 1 year ago 73
Scala Question

porting python to scala

I am trying to port the following Python code (Spark SQL: distance to the nearest holiday)

# Find the holidays surrounding `date`: `last_holiday` is the one before it and
# `next_holiday` the first one on or after it (None when there is no such holiday).
# Assumes index.value is a non-empty sequence sorted in ascending order - TODO confirm.
last_holiday = index.value[0]
for next_holiday in index.value:
    if next_holiday >= date:
        # First holiday on or after `date`: stop so next_holiday keeps this value.
        break
    last_holiday = next_holiday
# last_holiday still holds a holiday after `date`, so nothing precedes `date`.
if last_holiday > date:
    last_holiday = None
# The loop ran to the end without breaking, so no holiday is on or after `date`.
if next_holiday < date:
    next_holiday = None

to Scala. I do not (yet) have much Scala experience, but using `break` does not seem clean or like the Scala way to do it.
Please, can you show me how to "cleanly" port this to Scala? Here is my attempt so far:

// NOTE(review): work-in-progress port from the question. The excerpt appears to have lost its
// indentation and several closing braces, so it does not compile as shown.
// Presumably `last_holiday` is a `var` holding an Option and `current` is a LocalDate, both
// declared outside this excerpt - TODO confirm.
breakable {
for (next_holiday <- indexAT.value) {
val next = next_holiday.toLocalDate
println("next ", next)
println("last ", last_holiday)

// Leave the loop at the first holiday that is on or after `current`.
if (next.isAfter(current) || next.equals(current)) break
// check do I actually get here?
last_holiday = Option(next)
} // TODO this is so not scala and ugly ...
// Drop last_holiday when it lies after `current`.
if (last_holiday.isDefined) {
if (last_holiday.get.isAfter(current)) {
last_holiday = None
// NOTE(review): this second check tests last_holiday again; it looks like it was meant to
// clear the "next" holiday instead - confirm against the Python original.
if (last_holiday.isDefined) {
if (last_holiday.get.isBefore(current)) {
// TODO use one more var because out of scope
// NOTE(review): `next` is a val declared inside the loop body, so it is not assignable here.
next = None

Here is the same code in a bit more context.
Also, I am not sure how "big" the scope of the `breakable` block around the `break` should be, but I hope to get rid of it in a Scala-native port of the code.

Answer Source

So this isn't a direct port but I think it is closer to idiomatic Scala. I would treat the list of holidays as a list of sequential pairs and then find which pair the input date lies between.

Here is a full example:

scala> import java.sql.Date
import java.sql.Date

scala> import java.text.SimpleDateFormat
import java.text.SimpleDateFormat

scala> :pa
// Entering paste mode (ctrl-D to finish)
/** Parses a date written as `MM/dd/yyyy` (for example "12/25/2016") into a `java.sql.Date`.
  *
  * @param in date text in month/day/year order
  * @return the parsed date as a `java.sql.Date`
  * @throws java.text.ParseException if `in` cannot be parsed with the pattern
  */
def parseDate(in: String): java.sql.Date = {
    // SimpleDateFormat is not thread-safe, so a fresh formatter is created on every call.
    val formatter = new SimpleDateFormat("MM/dd/yyyy")
    val d = formatter.parse(in)
    new java.sql.Date(d.getTime)
}
// Exiting paste mode, now interpreting.
parseDate: (in: String)java.sql.Date

scala> val holidays = Seq("11/24/2016", "12/25/2016", "12/31/2016").map(parseDate)
holidays: Seq[java.sql.Date] = List(2016-11-24, 2016-12-25, 2016-12-31)

scala> val hP = sc.broadcast(holidays.zip(holidays.tail))
hP: org.apache.spark.broadcast.Broadcast[Seq[(java.sql.Date, java.sql.Date)]] = Broadcast(4)

scala> def geq(d1: Date, d2: Date) = d1.after(d2) || d1.equals(d2) // true when d1 is later than or equal to d2
geq: (d1: java.sql.Date, d2: java.sql.Date)Boolean

scala> def leq(d1: Date, d2: Date) = d1.before(d2) || d1.equals(d2) // true when d1 is earlier than or equal to d2
leq: (d1: java.sql.Date, d2: java.sql.Date)Boolean

scala> :pa
// Entering paste mode (ctrl-D to finish)
// UDF returning the holidays surrounding `inDate` as (previous, next); a side is None when
// `inDate` lies before the first or after the last broadcast holiday.
// Assumes the broadcast pair list is non-empty and in ascending date order - TODO confirm.
val findNearestHolliday = udf((inDate: Date) => {
    // Pairs of consecutive holidays, read from the broadcast variable.
    val hP_l = hP.value
    // First pair whose closed interval [d1, d2] contains inDate, if there is one.
    val dates = hP_l.collectFirst{case (d1, d2) if (geq(inDate, d1) && leq(inDate, d2)) => (Some(d1), Some(d2))}
    // No enclosing pair: inDate is either before the first holiday or after the last one.
    dates.getOrElse(if (leq(inDate, hP_l.head._1)) (None, Some(hP_l.head._1)) else (Some(hP_l.last._2), None))
})
// Exiting paste mode, now interpreting.
findNearestHolliday: org.apache.spark.sql.UserDefinedFunction = UserDefinedFunction(<function1>,StructType(StructField(_1,DateType,true), StructField(_2,DateType,true)),List(DateType))

scala> val df = Seq((1, parseDate("11/01/2016")), (2, parseDate("12/01/2016")), (3, parseDate("01/01/2017"))).toDF("id", "date")
df: org.apache.spark.sql.DataFrame = [id: int, date: date]

scala> val df2 = df.withColumn("nearestHollidays", findNearestHolliday($"date"))
df2: org.apache.spark.sql.DataFrame = [id: int, date: date, nearestHollidays: struct<_1:date,_2:date>]

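scala> df2.show
+---+----------+--------------------+
| id|      date|    nearestHollidays|
+---+----------+--------------------+
|  1|2016-11-01|   [null,2016-11-24]|
|  2|2016-12-01|[2016-11-24,2016-...|
|  3|2017-01-01|   [2016-12-31,null]|
+---+----------+--------------------+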

scala> df2.foreach{println}