将我的代码从OCaml / F#映射到Scala - 一些问题

时间:2014-12-04 10:44:29

标签: scala f# ocaml

我在空闲时间学习Scala - 作为一项学习练习,我将一些OCaml代码that I wrote about in another StackOverflow question翻译成了Scala。由于我是Scala的新手,我很欣赏一些建议......

但在提出我的问题之前 - 这是最初的OCaml代码:

let visited = Hashtbl.create 200000

let rec walk xx yy =
    let addDigits number =
        let rec sumInner n soFar =
            match n with
            | x when x<10  -> soFar+x
            | x -> sumInner (n/10) (soFar + n mod 10) in
        sumInner number 0 in
    let rec innerWalk (totalSoFar,listOfPointsToVisit) =
        match listOfPointsToVisit with
        | [] -> totalSoFar
        | _ ->
            innerWalk (
                listOfPointsToVisit
                (* remove points that we've already seen *)
                |> List.filter (fun (x,y) ->
                    match Hashtbl.mem visited (x,y) with
                    | true -> false (* remove *)
                    | _    -> (Hashtbl.add visited (x,y) 1 ; true))
                (* increase totalSoFar and add neighbours to list *)
                |> List.fold_left
                    (fun (sum,newlist) (x,y) ->
                        match (addDigits x)+(addDigits y) with
                        | n when n<26 ->
                            (sum+1,(x+1,y)::(x-1,y)::(x,y+1)::(x,y-1)::newlist)
                        | n -> (sum,newlist))
                    (totalSoFar,[])) in
    innerWalk (0,[(xx,yy)])

let _ =
    Printf.printf "Points: %d\n" (walk 1000 1000)

...这里是我将其翻译成的Scala代码:

import scala.collection.mutable.HashMap

val visited = new HashMap[(Int,Int), Int]

def addDigits(number:Int) = {
    def sumInner(n:Int, soFar:Int):Int =
      if (n<10)
        soFar+n
      else
        sumInner(n/10, soFar+n%10)
    sumInner(number, 0)
}

def walk(xx:Int, yy:Int) = {
    def innerWalk(totalSoFar:Int, listOfPointsToVisit:List[(Int,Int)]):Int = {
        if (listOfPointsToVisit.isEmpty) totalSoFar
        else {
            val newStep = 
                listOfPointsToVisit.
                // remove points that we've already seen
                filter(tupleCoords => {
                    if (visited.contains(tupleCoords))
                        false
                    else {
                        visited(tupleCoords)=1 
                        true
                    }
                }).
                // increase totalSoFar and add neighbours to list
                foldLeft( (totalSoFar,List[(Int,Int)]()) )( (state,coords) => {
                    val (sum,newlist) = state
                    val (x,y) = coords
                    if (addDigits(x)+addDigits(y) < 26)
                        (sum+1,(x+1,y)::(x-1,y)::(x,y+1)::(x,y-1)::newlist)
                    else
                        (sum,newlist)
                });
            innerWalk(newStep._1, newStep._2)
        }
    }
    innerWalk(0, List((xx,yy)))
}

println("Points: " + walk(1000,1000))

Scala代码编译并正常工作,报告正确的结果。

但是...

  • 除非我遗漏了什么,否则我在Scala中找不到管道运算符(即OCaml和F#的|>)所以我使用了相应的列表方法(filterfold Left )。在这种情况下,最终结果与原始结果非常接近,但我想知道 - 对于功能型解决方案而言,管道运营商不是一种普遍有利且更通用的方法吗?为什么Scala不配备它?

  • 在Scala中,我必须专门启动我的折叠状态(这是(Int, List[Int,Int])的元组,其中包含特定类型的空列表。简单来说,List()没有删除它 - 我必须明确指定List[(Int,Int)](),否则我得到了一个......相当困难的错误信息。我根据上下文对其进行了解密 - 它抱怨Nothing - 我意识到这个微小代码中唯一的地方Nothing出现的类型可能是我的空列表。但是,与OCaml相比,结果更加丑陋......我能做得更好吗?

  • 同样,OCaml能够将折叠的结果元组作为参数传递给innerWalk。在Scala中,我不得不分配一个变量并调用尾递归调用 innerWalk(newStep._1, newStep._2)。元组和函数参数之间似乎没有等价 - 即我不能在具有两个参数的函数中传递2-arity的元组 - 同样地,我不能将函数的参数元组解析为变量(I必须明确地指定statecoords并在折叠函数体内解构它们。我错过了什么吗?

总的来说,我对结果很满意 - 我会说如果我们将此示例的OCaml代码分级为100%,那么Scala大约为85-90% - 它比OCaml更冗长,但它是更接近OCaml而不是Java。我只是想知道我是否充分利用了Scala,或者是否错过了一些可以改进代码的构造(更有可能)。

请注意,我避免将我原来的OCaml模式匹配映射到Scala,因为在这种情况下我认为它是过度的 - if表达式在两个地方都更清晰。

提前感谢您的任何帮助/建议。

P.S。旁注 - 我在walk调用周围添加了时序指令(从而避免了JVM的启动成本)并测量了我的Scala代码 - 它的运行速度大约是OCaml速度的50% - 这很有趣,完全是同样的速度我从Mono中执行F#等效代码(如果您关心这种比较,请参阅我原来的SO问题以获取F#代码)。由于我目前在企业环境中工作,50%的速度是一个价格,我很乐意付出代价来编写类似ML的代码并且仍然可以访问浩如烟海的JVM /.NET生态系统(数据库,Excel文件生成等)。对不起OCaml,我确实试过你 - 但是you can't fully "speak" Oracle: - )

编辑1 :在@senia和@lmm提出建议后,代码为significantly improved。希望来自@lmm的更多关于foldMap和Shapeless如何帮助的建议: - )

编辑2 :我使用scalaz中的flatMap进一步清理了代码 - gist is here。不幸的是,这一变化也造成了10倍的大幅放缓 - 猜测foldMap完成的列表连接比foldLeft的“只添加一个新节点”慢得多。想知道如何更改代码以快速添加...

编辑3 :在@lmm的另一个建议之后,我将scalaz-flatMap版本从使用List切换到using immutable.Vector:这有助于提高速度,从慢了10倍......只比原始代码慢2倍。那么,干净的代码还是2倍的速度?决定,决定......: - )

2 个答案:

答案 0 :(得分:5)

  • Scalaz确实提供了|>运算符,或者您可以自己编写一个运算符。一般来说,在Scala中对它的需求要少得多,因为对象有方法,正如你在一些翻译中看到的那样(例如somethingThatReturnsList.filter(...)在OCaml你必须写somethingThatReturnsList |> List.filter(...)所以它没有内置到语言中。但如果你需要它,那就在那里。
  • foldLeft有点笼统;您可以使用例如编写更清晰的代码Scalaz foldMap(在你的元组的情况下,你可能还需要shapeless-contrib,以便派生适当的类型类实例)。但基本上是的,Scala类型推断将不如OCaml可靠,你会发现自己必须添加显式类型注释(有时因为不清楚Nothing错误消息) - 它是我们支付传统费用的代价-OO extends继承。
  • 您可以使用(innerWalk _).tupled来获取一个带元组的函数。或者您可以编写函数来接受元组并利用参数自动调整来调用它们而不使用显式元组语法。但是,是的,没有多参数函数的通用编码(你可以使用Shapeless将它们转换成那种形式),我怀疑很大程度上是因为JVM的兼容性。我怀疑如果现在编写标准库,它将使用HList来表示所有内容,普通函数和HList表示之间会有等价,但这将是一个非常难以改变的向后兼容的方式。

您似乎使用了相当多的if,而且您可以使用某些功能,例如: visited.put(tupleCoords, 1)返回值是否被替换的布尔值,因此您可以将其用作filter调用的整个正文。正如我所说,如果您愿意使用Scalaz,foldLeft可以被重写为更清晰的foldMap。我怀疑整个递归循环可以用命名结构表达,但没有立即想到,所以也许没有。

答案 1 :(得分:1)

我附加了两个具有更多惯用Scala代码的替代版本(我还略微优化了算法)。在这种情况下,我不认为必要的解决方案有任何问题,实际上可以更容易理解。

  // Impure functional version
  def walk2(xx: Int, yy: Int) = {
    val visited = new mutable.HashSet[(Int, Int)]

    def innerWalk(totalSoFar: Int, listOfPointsToVisit: Seq[(Int, Int)]): Int = {
      if (listOfPointsToVisit.isEmpty) totalSoFar
      else {
        val newStep = listOfPointsToVisit.foldLeft((totalSoFar, Seq[(Int, Int)]())) {
          case ((sum, newlist), tupleCoords@(x, y)) =>
            if (visited.add(tupleCoords) && addDigits(x) + addDigits(y) < 26)
              (sum + 1, (x + 1, y) +: (x - 1, y) +: (x, y + 1) +: (x, y - 1) +: newlist)
            else
              (sum, newlist)
        }

        innerWalk(newStep._1, newStep._2)
      }
    }

    innerWalk(0, Seq((xx, yy)))
  }

  // Imperative version, probably fastest
  def walk3(xx: Int, yy: Int) = {
    val visited = new mutable.HashSet[(Int, Int)]()
    val toVisit = new mutable.Queue[(Int, Int)]()

    def add(x: Int, y: Int) {
      val tupleCoords = (x, y)

      if (visited.add(tupleCoords) && addDigits(x) + addDigits(y) < 26)
        toVisit += tupleCoords
    }

    add(xx, yy)
    var count = 0

    while (!toVisit.isEmpty) {
      count += 1
      val (x, y) = toVisit.dequeue()
      add(x + 1, y)
      add(x - 1, y)
      add(x, y + 1)
      add(x, y - 1)
    }

    count
  }

编辑:改进的功能版本,在命令式版本中使用Queue