将对象方法提升到类中

时间:2016-07-14 21:47:41

标签: scala macros

我想知道是否有办法从伴随对象中自动提升静态方法,因此您可以通过实例引用它们。像:

object Foo {
  def bar = 5
}
case class Foo()

val f = Foo()
f.bar

在删除实例参数的同时,自动提升需要类实例的对象方法的额外奖励点。 (假设this

object Foo {
  def bar(f: Foo) = ...
}
case class Foo()

val f = Foo()
f.bar // instead of Foo.bar(f)

1 个答案:

答案 0 :(得分:1)

如果您愿意接受一个编译器插件macro paradise(希望它可以与下一版本的编译器一起提供?),您可以制作一个可以完成所需操作的注释。具体来说(我认为这就是你想要的):

  • 会自动将需要类实例的对象方法作为第一个arg提升为类方法(首先是arg = this
  • 将提升剩余的对象方法作为类方法(忽略this

您必须配置内容以构建宏(或只是克隆this macro bare-bones repo,将以下宏代码粘贴到Macros.scala中并在Test.scala - sbt compile中进行试验照顾其他一切)。

import scala.annotation.StaticAnnotation
import scala.language.experimental.macros
import scala.reflect.macros._

class liftFromObject extends StaticAnnotation {
  def macroTransform(annottees: Any*) = macro liftFromObjectMacro.impl
}
object liftFromObjectMacro {
  def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._

    annottees.map(_.tree) match {
      case List(q"case class $cla(..$fields) extends ..$classBases { ..$classBody }"
               ,q"object $obj extends ..$objectBases { ..$objectBody }") =>

        /* filter out from the object the functions we want to have in the class */
        val newMethods = objectBody collect {

          /* functions whose first arg indicates they are methods */
          case q"def $name($arg: $t, ..$args) = { ..$body }"
            if t.toString == cla.toString =>
            q"def $name(..$args) = { val $arg = this; ..$body }"

          /* other functions */
          case func@q"def $name(..$args) = { ..$body }" => func
        }

        /* return the modified class and companion object */
        c.Expr(q"""
           case class $cla(..$fields) extends ..$classBases {
             ..$classBody;
             ..$newMethods
           }
           object $obj extends ..$objectBases { ..$objectBody }
        """)

      case _ => c.abort(c.enclosingPosition, "Invalid annottee")
    }
  }
}

从根本上说,你所做的只是玩AST,而quasiquotes使得它非常简单。通过上述内容,我可以运行下面的代码并获得Bar(3)3的打印输出。

object Main extends App {
  val t = Bar(1)
  println(t.inc())
  println(t.two)
}

object Bar {
  def inc(b: Bar) = {
    val Bar(i) = b; Bar(i+2)
  }
  def two() = 3
}

@liftFromObject
case class Bar(i: Int)

请注意,由于这只是AST级操作,因此Scala的不同语法可能会有点脆弱,以不同的方式声明相同的内容......