在以日期为键的Seq上滚动总和

时间:2019-12-28 22:41:33

标签: scala

我正在从事一个小型个人项目,并试图将代码库从Python重写为Scala,以便使我成为更有能力的函数式程序员。

我正在使用Seq,其中包含库存数据,并且需要创建每天交易量的连续总和。

到目前为止,我的代码是:

import org.joda.time.DateTime
import org.joda.time.format.DateTimeFormat

case class SymbolData(date: DateTime, open: Double, high: Double, low: Double, close: Double, adjClose: Double, volume: Int)

def dateTimeHelper(date: String): DateTime = {
       DateTimeFormat.forPattern("yyyy-MM-dd").parseDateTime(date)
}

val sampleData: Seq[SymbolData] = Seq(
       SymbolData(dateTimeHelper("2019-01-01"), 1.0, 1.0, 1.0, 1.0, 1.0, 10),
       SymbolData(dateTimeHelper("2019-01-02"), 3.0, 2.0, 5.0, 2.0, 8.0, 20),
       SymbolData(dateTimeHelper("2019-01-03"), 1.0, 1.0, 1.0, 1.0, 1.0, 10),
       SymbolData(dateTimeHelper("2019-01-04"), 4.0, 3.0, 2.5, 2.3, 5.3, 7))

并非所有日期都可能存在,因此我认为使用滑动窗口并不适合。对于输出,我需要获取一个Seq of ints,其中包含最近2天的数据之和,例如:

Seq(10, 30, 30, 17) # 2019-01-01 has only 1 day with sum value of 10 since there is no data for 2018-12-31, 2019-01-02 would be 30 since we have 2nd and 1st of Jan present, etc...

这在基本python中并不是很难做到,但是对于Scala来说,似乎有很多选择(递归使用folds?),但是我在语法和实现方面都在挣扎。有人能对此有所启发吗?

1 个答案:

答案 0 :(得分:1)

您说“并非所有日期都可能存在”,但未指定如何处理日期间隔。

我猜这里的输出应该包括所有2天的总和,包括间隔天。

import java.time.LocalDate
import java.time.temporal.ChronoUnit.DAYS

case class SymbolData(date     : LocalDate
                     ,open     : Double
                     ,high     : Double
                     ,low      : Double
                     ,close    : Double
                     ,adjClose : Double
                     ,volume   : Int)

val sampleData: List[SymbolData] = List(
  SymbolData(LocalDate.parse("2019-01-01"), 1.0, 1.0, 1.0, 1.0, 1.0, 10),
  SymbolData(LocalDate.parse("2019-01-02"), 3.0, 2.0, 5.0, 2.0, 8.0, 20),
  SymbolData(LocalDate.parse("2019-01-03"), 1.0, 1.0, 1.0, 1.0, 1.0, 10),
  SymbolData(LocalDate.parse("2019-01-04"), 4.0, 3.0, 2.5, 2.3, 5.3, 7),
  // 1 day gap
  SymbolData(LocalDate.parse("2019-01-06"), 4.4, 3.3, 2.2, 2.3, 1.3, 13),
  // 2 day gap
  SymbolData(LocalDate.parse("2019-01-09"), 2.4, 2.2, 1.5, 3.1, 0.9, 21),
  SymbolData(LocalDate.parse("2019-01-10"), 2.4, 2.2, 1.5, 3.1, 0.9, 11)
)

val volByDate = sampleData.foldLeft(Map.empty[LocalDate,Int]){
  case (m,sd) => m + (sd.date -> sd.volume)
}.withDefaultValue(0)

val startDate = sampleData.head.date
val endDate   = sampleData.last.date

val rslt = List.unfold(startDate){ date =>  //<--Scala 2.13
  if (date isAfter endDate) None
  else
    Some(volByDate(date) + volByDate(date.minus(1L,DAYS)) -> date.plus(1L,DAYS))
}
//rslt: List[Int] = List(10, 30, 30, 17, 7, 13, 13, 0, 21, 32)