将 Python 代码移植到 Scala

时间:2016-11-23 14:57:44

标签: python scala

我正在尝试把下面这段 Python 代码(来自问题 “spark sql distance to nearest holiday”)移植到 Scala:

# Find the holidays that bracket `date`: the last holiday on or before it and
# the first holiday on or after it. Assumes index.value is a non-empty
# sequence of holiday dates sorted ascending -- TODO confirm (the early
# `break` is only correct for sorted input).
last_holiday = index.value[0]
    # Walk forward until the first holiday on or after `date`; every holiday
    # strictly before `date` becomes the running "last" candidate.
    for next_holiday in index.value:
        if next_holiday >= date:
            break
        last_holiday = next_holiday
    # Only true when even the first holiday is after `date`: no earlier holiday.
    if last_holiday > date:
        last_holiday = None
    # Only true when the loop ran off the end without breaking: no holiday on
    # or after `date`.
    if next_holiday < date:
        next_holiday = None
我(目前)还没有多少 Scala 经验,但 `break` 的写法似乎既不简洁,也不符合 Scala 的风格。能否请你告诉我如何“干净地”把它移植到 Scala?

// Attempted Scala port of the Python loop above (question code, kept as posted).
// NOTE(review): `breakable`/`break` presumably come from
// scala.util.control.Breaks; that import is not shown in this fragment.
// NOTE(review): `last_holiday` (a reassigned Option, so it must be a `var`),
// `current` and `indexAT` are defined outside this fragment.
breakable {
      for (next_holiday <- indexAT.value) {
        val next = next_holiday.toLocalDate
        println("next ", next)
        println("last ", last_holiday)

        // Leave the loop at the first holiday on or after `current`.
        if (next.isAfter(current) || next.equals(current)) break
        // check do I actually get here?
        // (Yes: this line is reached for every holiday strictly before `current`.)
        last_holiday = Option(next)
      } // TODO this is so not scala and ugly ...
      // NOTE(review): the checks below are still inside `breakable`, so they are
      // skipped whenever `break` fires above; they only run when the loop
      // finishes without breaking. In the Python original they run after the loop.
      if (last_holiday.isDefined) {
        if (last_holiday.get.isAfter(current)) {
          last_holiday = None
        }
      }
      // NOTE(review): the Python original tests the *next* holiday here
      // (`next_holiday < date`), not `last_holiday`. Also, `next` is a `val`
      // local to the loop body, so `next = None` cannot compile at this point.
      if (last_holiday.isDefined) {
        if (last_holiday.get.isBefore(current)) {
          // TODO use one more var because out of scope
          next = None
        }
      }
    }

更多代码见:https://gist.github.com/geoHeil/ff513b97a2b3e16241fdd9c8b0f3bdfb 。此外,我不确定是否应该放弃这个“大”——但我希望在 Scala 原生的移植版本中去掉它。

2 个答案:

答案 0 :(得分:2)

这并不是逐行的直接移植,但我认为它更接近惯用的 Scala 写法。我的思路是:把假期列表看作由相邻两个假期组成的日期对列表,然后找出输入日期落在哪一对之间。

以下是一个完整的例子:

scala> import java.sql.Date
import java.sql.Date

scala> import java.text.SimpleDateFormat
import java.text.SimpleDateFormat

scala> :pa
// Entering paste mode (ctrl-D to finish)
/** Parses a date written as `MM/dd/yyyy` (e.g. `"11/24/2016"`) into a `java.sql.Date`.
  *
  * A fresh formatter is created on every call, so no formatter state is shared
  * between calls.
  *
  * @param in date text in `MM/dd/yyyy` form
  * @return the parsed instant (midnight in the JVM default time zone) as a `java.sql.Date`
  * @throws java.text.ParseException if `in` cannot be parsed
  */
def parseDate(in: String): java.sql.Date = {
  val epochMillis = new SimpleDateFormat("MM/dd/yyyy").parse(in).getTime
  new java.sql.Date(epochMillis)
}
// Exiting paste mode, now interpreting.
parseDate: (in: String)java.sql.Date

scala> val holidays = Seq("11/24/2016", "12/25/2016", "12/31/2016").map(parseDate)
holidays: Seq[java.sql.Date] = List(2016-11-24, 2016-12-25, 2016-12-31)

scala> val hP = sc.broadcast(holidays.zip(holidays.tail))
hP: org.apache.spark.broadcast.Broadcast[Seq[(java.sql.Date, java.sql.Date)]] = Broadcast(4)

scala> def geq(d1: Date, d2: Date) = d1.after(d2) || d1.equals(d2)
geq: (d1: java.sql.Date, d2: java.sql.Date)Boolean

scala> def leq(d1: Date, d2: Date) = d1.before(d2) || d1.equals(d2)
leq: (d1: java.sql.Date, d2: java.sql.Date)Boolean

scala> :pa
// Entering paste mode (ctrl-D to finish)
// Spark UDF mapping a date to (previous holiday, next holiday) as a pair of
// Options. A side is None when the date lies before the first holiday or after
// the last one. `hP` broadcasts consecutive (holiday, following holiday) pairs
// built with holidays.zip(holidays.tail); this assumes the holiday list is
// sorted ascending.
val findNearestHolliday = udf((inDate: Date) => {
    val hP_l = hP.value
    // First pair [d1, d2] that contains inDate, inclusive on both ends (geq/leq).
    val dates = hP_l.collectFirst{case (d1, d2) if (geq(inDate, d1) && leq(inDate, d2)) => (Some(d1), Some(d2))}
    // No pair matched: inDate is either on/before the first holiday or after the last.
    // NOTE(review): hP_l.head / hP_l.last throw on an empty sequence, which
    // happens when fewer than two holidays were broadcast. Consider guarding.
    dates.getOrElse(if (leq(inDate, hP_l.head._1)) (None, Some(hP_l.head._1)) else (Some(hP_l.last._2), None))
})
// Exiting paste mode, now interpreting.
findNearestHolliday: org.apache.spark.sql.UserDefinedFunction = UserDefinedFunction(<function1>,StructType(StructField(_1,DateType,true), StructField(_2,DateType,true)),List(DateType))

scala> val df = Seq((1, parseDate("11/01/2016")), (2, parseDate("12/01/2016")), (3, parseDate("01/01/2017"))).toDF("id", "date")
df: org.apache.spark.sql.DataFrame = [id: int, date: date]

scala> val df2 = df.withColumn("nearestHollidays", findNearestHolliday($"date"))
df2: org.apache.spark.sql.DataFrame = [id: int, date: date, nearestHollidays: struct<_1:date,_2:date>]

scala> df2.show
+---+----------+--------------------+
| id|      date|    nearestHollidays|
+---+----------+--------------------+
|  1|2016-11-01|   [null,2016-11-24]|
|  2|2016-12-01|[2016-11-24,2016-...|
|  3|2017-01-01|   [2016-12-31,null]|
+---+----------+--------------------+

scala> df2.foreach{println}
[3,2017-01-01,[2016-12-31,null]]
[1,2016-11-01,[null,2016-11-24]]
[2,2016-12-01,[2016-11-24,2016-12-25]]

答案 1 :(得分:0)

我尝试用 Scala 实现了这个功能:

scala> import java.text.SimpleDateFormat
import java.text.SimpleDateFormat

scala> import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeUnit

scala> val sdf = new SimpleDateFormat("dd/MM/yyyy")
sdf: java.text.SimpleDateFormat = java.text.SimpleDateFormat@d936eac0

//Here I've just assumed that the 15th of every other month is a public holiday
scala> val publicHolidays = for(interval <- 4 to 12 by 2) yield sdf.parse(s"15/$interval/2016")
publicHolidays: scala.collection.immutable.IndexedSeq[java.util.Date] = Vector(Fri Apr 15 00:00:00 BST 2016, Wed Jun 15 00:00:00 BST 2016, Mon Aug 15 00:00:00 BST 2016, Sat Oct 15 00:00:00 BST 2016, Thu Dec 15 00:00:00 GMT 2016)

//Today's date
scala> val currentDate = sdf.parse("23/11/2016")
currentDate: java.util.Date = Wed Nov 23 00:00:00 GMT 2016

scala> def findDaysTillNextHoliday: Long = {
     | val nextHolday = publicHolidays.toList.filter(_.after(currentDate)).head
     | TimeUnit.DAYS.convert(nextHolday.getTime - currentDate.getTime, TimeUnit.MILLISECONDS)
     | }
findDaysTillNextHoliday: Long

scala> findDaysTillNextHoliday
res0: Long = 22 //i.e 22 days till the next holiday which is 15th of december 2016

距上一个假期已过去的天数:

def findDaysSinceLastHoliday: Long = {
      | val lastHoliday = publicHolidays.toList.filter(_.before(currentDate)).last
      | TimeUnit.DAYS.convert(currentDate.getTime - lastHoliday.getTime, TimeUnit.MILLISECONDS)
      |}
findDaysSinceLastHoliday: Long

scala> findDaysSinceLastHoliday
res1: Long = 39 //i.e 39 days since the last holiday which was 15th of October 2016