简化Scala中的表达式

时间:2011-07-02 18:21:21

标签: scala

我有这样的案例类:

abstract class Tree

case class Sum(l: Tree, r: Tree) extends Tree

case class Var(n: String) extends Tree

case class Const(v: Int) extends Tree

现在我写下这样的对象:

object Main {

  type Environment = String => Int

  def derive(t: Tree, v: String): Tree = t match {
    case Sum(l, r) => Sum(derive(l, v), derive(r, v))
    case Var(n) if (v == n) => Const(1)
    case _ => Const(0)
  }

  def eval(t: Tree, env: Environment): Int = t match {
    case Sum(l, r) => eval(l, env) + eval(r, env)
    case Var(n) => env(n)
    case Const(v) => v
  }

  def simple(t: Tree): Const = t match {
    case Sum(l, r) if (l.isInstanceOf[Const] && r.isInstanceOf[Const]) => Const(l.asInstanceOf[Const].v + r.asInstanceOf[Const].v)
    case Sum(l, r) if (l.isInstanceOf[Sum] && r.isInstanceOf[Sum]) => Const(simple(l).v+ simple(r).v)
    case Sum(l, r) if (l.isInstanceOf[Sum]) => Const(simple(l).v + r.asInstanceOf[Const].v)
    case Sum(l, r) if (r.isInstanceOf[Sum]) => Const(simple(r).v + l.asInstanceOf[Const].v)
  }

  def main(args: Array[String]) {
    val exp: Tree = Sum(Sum(Var("x"), Var("x")), Sum(Const(7), Var("y")))
    val env: Environment = {
      case "x" => 5
      case "y" => 7
    }
    println("Expression: " + exp)
    println("Evaluation with x=5, y=7: " + eval(exp, env))
    println("Derivative relative to x:\n " + derive(exp, "x"))
    println("Derivative relative to y:\n " + derive(exp, "y"))
    println("Simplified expression:\n" + simple(derive(exp, "x")))
  }


}

我是scala的新手。是否可以使用少量代码编写方法simple并且可能采用scala方式?

感谢您的建议。

4 个答案:

答案 0 :(得分:6)

你快到了。在Scala中,提取器可以嵌套:

def simple(t: Tree): Const = t match {
  case Sum(Const(v1), Const(v2)) => Const(v1 + v2)
  case Sum(s1 @ Sum(_,_), s2 @ Sum(_, _)) => Const(simple(s1).v+ simple(s2).v)
  case Sum(s @ Sum(_, _), Const(v)) => Const(simple(s).v + v)
  case Sum(Const(v), s @ Sum(_, _)) => Const(simple(s).v + v)
}

当然,这会给你一些关于不完整匹配的警告,并且sx @ Sum(_,_)反复暗示可能有更好的方法,包括在根级别匹配Const和Var并进行更多递归要求简单。

答案 1 :(得分:1)

虽然这个问题已经结束,但我认为这个版本应该更好一些,

def simplify(t: Tree): Tree = t match {
    case Sum(Const(v1), Const(v2)) => Const(v1 + v2)
    case Sum(Const(v1), Sum(Const(v2), rr)) => simplify(Sum(Const(v1 + v2), simplify(rr)))
    case Sum(l, Const(v)) => simplify(Sum(Const(v), simplify(l)))
    case Sum(l, Sum(Const(v), rr)) => simplify(Sum(Const(v), simplify(Sum(l, rr))))
    case Sum(Sum(ll, lr), r) => simplify(Sum(ll, simplify(Sum(lr, r))))
    case Sum(Var(n), r) => Sum(simplify(r), Var(n))
    case _ => t
}

它似乎适用于"复杂"带变量的表达式。

答案 2 :(得分:0)

只是一个小小的改进:

def derive(t: Tree, v: String): Tree = t match {
    case Sum(l, r) => Sum(derive(l, v), derive(r, v))
    case Var(`v`) => Const(1)
    case _ => Const(0)
}

答案 3 :(得分:0)

这个怎么样:

def simplify(t: Tree): Tree = t match {
    case Sum(Const(v1),Const(v2)) => Const(v1+v2)
    case Sum(left,right) => simplify(Sum(simplify(left),simplify(right)))
    case _ => t //Not necessary, but for completeness
}

请注意,它返回一个Tree,而不是Const,因此它应该能够简化带变量的树。

我正在学习Scala,所以对于为什么这不起作用等任何建议都非常受欢迎: - )


编辑:刚刚发现第二种情况在使用变量时会导致无限循环。用以下代替:

case Sum(left,right) => Sum(simplify(left),simplify(right))

不幸的是,当leftright返回Const时会中断,这可以进一步简化(例如Sum(Const(2),Const(3)))。