Annotation macro that rewrites and implements a trait does not handle generics correctly

Date: 2015-12-01 15:18:07

Tags: scala-macros

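I'm writing a macro that needs to create a class which rewrites a trait: the class should have the same trait methods/args, but with a different return type.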

So say we have:

trait MyTrait[T]
{
  def x(t1: T)(t2: T): T
}

@AnnProxy
class MyClass[T] extends MyTrait[T]

MyClass will be rewritten as:

class MyClass[T] {
 def x(t1: T)(t2: T): R[T]
}

(so x will now return R[T] instead of T)
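For reference, here is a hand-written version of the intended expansion. It is a sketch that assumes `R` is the empty marker trait from the question and that the body may return any `R[T]` (the macro output leaves it as `???`). It also illustrates why the expansion cannot keep `extends MyTrait[T]`.

```scala
// R is the empty marker trait from the question.
trait R[T]

trait MyTrait[T] {
  def x(t1: T)(t2: T): T
}

// Hand-written target of the rewrite. There is no `extends MyTrait[T]`:
// x(t1: T)(t2: T): R[T] would clash with the inherited x(t1: T)(t2: T): T
// (same parameters, incompatible result type).
class MyClass[T] {
  // Placeholder body; the macro in the question generates `???` here.
  def x(t1: T)(t2: T): R[T] = new R[T] {}
}

object ExpansionDemo {
  val my = new MyClass[Int]
  // Compiles because the T in x is MyClass's own T, instantiated to Int.
  val r: R[Int] = my.x(5)(6)

  def main(args: Array[String]): Unit = println(r)
}
```

This is the behaviour the generated class should reproduce: `my.x(5)(6)` type-checks because both parameter lists mention the class's type parameter.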

I wrote the macro and debugged it, and it generates this code:

Expr[Any](class MyClass[T] extends scala.AnyRef {
   def <init>() = {
     super.<init>();
     ()
   };
   def x(t1: T)(t2: T): macrotests.R[T] = $qmark$qmark$qmark
 })
@AnnProxy

As you can see, the signature looks fine. But when I try to use the macro, I get a compile error:

    val my = new MyClass[Int]
    my.x(5)(6)

    Error:(14, 7) type mismatch;
     found   : Int(5)
     required: T
        my.x(5)(6)
             ^

So it looks like the method's type parameter T is not the same as the class's [T]. Any ideas how to fix it?
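That suspicion is most likely right: `internal.valDef(s)` and `method.returnType` yield typed trees and types whose `T` is bound to `MyTrait`'s type parameter symbol rather than `MyClass`'s, even though both print as `T`, so the typer rejects `Int` for that foreign `T`. The sketch below reproduces the distinction with runtime reflection (`scala.reflect.runtime.universe`, which exposes the same `Symbol`/`Type` API as `c.universe`), assuming Scala 2.11+ with scala-reflect on the classpath. The type parameters are renamed `MT` and `MCT` so the two symbols are easy to tell apart; `infoIn` re-expresses the trait member's signature as seen from the class.

```scala
import scala.reflect.runtime.universe._

// Renamed type parameters (MT for the trait, MCT for the class);
// in the question both are called T.
trait MyTrait[MT] {
  def x(t1: MT)(t2: MT): MT
}

class MyClass[MCT] extends MyTrait[MCT] {
  def x(t1: MCT)(t2: MCT): MCT = t1
}

object SignatureDemo {
  val clazz: ClassSymbol = typeOf[MyClass[Int]].typeSymbol.asClass
  val traitSym: ClassSymbol = typeOf[MyTrait[Int]].typeSymbol.asClass

  // The declaration of x as it lives in the trait.
  val x: MethodSymbol = traitSym.toType.decl(TermName("x")).asMethod

  // Raw signature (t1: MT)(t2: MT)MT: it mentions the trait's MT.
  val raw: Type = x.info

  // Signature seen as a member of MyClass[MCT] (asSeenFrom under the hood):
  // (t1: MCT)(t2: MCT)MCT, i.e. it mentions the class's own type parameter.
  val seen: Type = x.infoIn(clazz.toType)

  def main(args: Array[String]): Unit = {
    println(s"raw:  $raw")
    println(s"seen: $seen")
  }
}
```

Inside the macro the analogous call would be `decl.infoIn(clazz.toType)`. One likely caveat: `c.typecheck(classDef)` works on a copy whose symbols are not those of the class the macro finally emits, so splicing *typed* trees may still point at the wrong `T`; emitting types as untyped trees that refer to `T` by name avoids that.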

Here is my macro so far. I'm not good at macros (I got a lot of help through Stack Overflow), but this is its current state:

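import scala.annotation.{StaticAnnotation, compileTimeOnly}
import scala.language.experimental.macros
import scala.reflect.macros.whitebox

@compileTimeOnly("enable macro paradise to expand macro annotations")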
class AnnProxy extends StaticAnnotation
{
    def macroTransform(annottees: Any*): Any = macro IdentityMacro.impl
}

trait R[T]

object IdentityMacro
{

private val SDKClasses = Set("java.lang.Object", "scala.Any")

def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._

    def showInfo(s: String) = c.info(c.enclosingPosition, s.split("\n").mkString("\n |---macro info---\n |", "\n |", ""), true)

    val classDef = annottees.map(_.tree).head.asInstanceOf[ClassDef]
    val clazz = c.typecheck(classDef).symbol.asClass
    val tparams = clazz.typeParams
    val baseClasses = clazz.baseClasses.tail.filter(clz => !SDKClasses(clz.fullName))
    val methods = baseClasses.flatMap {
        base =>
            base.info.decls.filter(d => d.isMethod && d.isPublic).map { decl =>
                val termName = decl.name.toTermName
                val method = decl.asMethod
                val params = method.paramLists.map(_.map {
                    s =>
                        val vd = internal.valDef(s)

                        val f = tparams.find(_.name == vd.tpt.symbol.name)
                        val sym = if (f.nonEmpty) f.get else vd.tpt.symbol

                        q"val ${vd.name} : $sym "
                })
                val paramVars = method.paramLists.flatMap(_.map(_.name))

                q""" def $termName (...$params)(timeout:scala.concurrent.duration.FiniteDuration) : macrotests.R[${method.returnType}] = {
            ???
           }"""
            }
    }

    val cde = c.Expr[Any] {
        q"""
      class ${classDef.name} [..${classDef.tparams}] {
        ..$methods
      }
  """
    }
    showInfo(show(cde))
    cde
}
}
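EDIT: I managed to solve it by building the class as a string and then compiling it with c.parse. It feels like a hack, but it works. There must be a better way that manipulates the trees directly.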

package macrotests

import scala.annotation.{StaticAnnotation, compileTimeOnly}
import scala.language.experimental.macros
import scala.reflect.macros.whitebox

@compileTimeOnly("enable macro paradise to expand macro annotations")
class AnnProxy extends StaticAnnotation
{
    def macroTransform(annottees: Any*): Any = macro AnnProxyMacro.impl
}

trait R[T]

trait Remote[T]

object AnnProxyMacro
{

private val SDKClasses = Set("java.lang.Object", "scala.Any")

def impl(c: whitebox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._

    val classDef = annottees.map(_.tree).head.asInstanceOf[ClassDef]
    val clazz = c.typecheck(classDef).symbol.asClass

    val baseClasses = clazz.baseClasses.tail.filter(clz => !SDKClasses(clz.fullName))
    val methods = baseClasses.flatMap {
        base =>
            base.info.decls.filter(d => d.isMethod && d.isPublic).map { decl =>
                val termName = decl.name.toTermName
                val method = decl.asMethod
                val params = method.paramLists.map(_.map {
                    s =>
                        val vd = internal.valDef(s)
                        val tq = vd.tpt
                        s"${vd.name} : $tq"
                })
                val paramVars = method.paramLists.flatMap(_.map(_.name))
                val paramVarsArray = paramVars.mkString("Array(", ",", ")")


                val paramsStr = params.map(_.mkString("(", ",", ")")).mkString(" ")
                val retTpe = method.returnType.typeArgs.mkString("-unexpected-")
                s""" def $termName $paramsStr (timeout:scala.concurrent.duration.FiniteDuration) : macrotests.Remote[$retTpe] = {
      println($paramVarsArray.toList)
            new macrotests.Remote[$retTpe] {}
           }"""
            }
    }

    val tparams = clazz.typeParams.map(_.name)
    val tparamsStr = if (tparams.isEmpty) "" else tparams.mkString("[", ",", "]")
    val code =
        s"""
           |class ${classDef.name}$tparamsStr (x:Int) {
           |${methods.mkString("\n")}
           |}
         """.stripMargin
    //      print(code)
    val cde = c.Expr[Any](c.parse(code))
    cde
}
}
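Regarding a tree-based alternative to `c.parse`: the string version probably works because the parsed code is untyped, so the compiler resolves `T` in the scope of the new class. A similar effect can likely be had by converting each `Type` into an untyped tree that names type parameters by name. `typeToTree` below is a hypothetical helper, demonstrated with runtime reflection (Scala 2.11+, scala-reflect on the classpath); inside a macro the same tree constructors are available from `c.universe`. It only handles plain type references.

```scala
import scala.reflect.runtime.universe._

// Example subject: get's result type mentions the class's type parameter A.
class Box[A] {
  def get: List[A] = Nil
}

object UntypedTypeTrees {
  // Hypothetical helper: rebuild a Type as an untyped tree, so the typer
  // resolves every name again in the scope where the tree is spliced.
  // Type parameters become bare identifiers (A); classes become fully
  // qualified selections (_root_.scala.collection.immutable.List).
  // Anything other than a plain TypeRef falls back to a typed TypeTree.
  def typeToTree(tpe: Type): Tree = tpe.dealias match {
    case TypeRef(_, sym, Nil) if sym.isParameter =>
      Ident(sym.name.toTypeName)
    case TypeRef(_, sym, args) =>
      val base = qualified(sym)
      if (args.isEmpty) base else AppliedTypeTree(base, args.map(typeToTree))
    case other =>
      TypeTree(other)
  }

  // Builds _root_.a.b.C from the symbol's full name.
  private def qualified(sym: Symbol): Tree = {
    val parts = sym.fullName.split('.').toList
    val prefix = parts.init.foldLeft[Tree](Ident(termNames.ROOTPKG)) {
      (qual, part) => Select(qual, TermName(part))
    }
    Select(prefix, TypeName(parts.last))
  }

  def main(args: Array[String]): Unit = {
    val resultType = typeOf[Box[Int]].decl(TermName("get")).asMethod.returnType
    println(showCode(typeToTree(resultType)))
  }
}
```

Spliced into `q"def $name(...$params): macrotests.R[$tree] = ???"`, such a tree would leave `A` (or `T`) to be bound by whatever type parameter is in scope in the generated class.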

1 Answer

Answer 0 (score: 2)

The code is long; you can view it on GitHub: https://github.com/1178615156/scala-macro-example/blob/master/stackoverflow/src/main/scala/so/AnnotationWithTrait.scala

import scala.annotation.StaticAnnotation
import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context

/**
  * Created by yu jie shui on 2015/12/2.
  */

class AnnotationWithTrait extends StaticAnnotation {
  def macroTransform(annottees: Any*): Any = macro AnnotationWithTraitImpl.apply

}

class AnnotationWithTraitImpl(val c: Context) {

  import c.universe._

  val SDKClasses = Set("java.lang.Object", "scala.Any")

  def showInfo(s: String) = c.info(c.enclosingPosition, s.split("\n").mkString("\n |---macro info---\n |", "\n |", ""), true)

  def apply(annottees: c.Expr[Any]*) = {

    val classDef = annottees.map(_.tree).head.asInstanceOf[ClassDef]

    val superClassSymbol= c.typecheck(classDef).symbol.asClass.baseClasses.tail
      .filterNot(e => SDKClasses.contains(e.fullName)).reverse

    val superClassTree= classDef match {
      case q"$mod class $name[..$t](..$params) extends ..$superClass { ..$body }" =>
        (superClass: List[Tree]).filterNot(e =>
          typeOf[Object].members.exists(_.name == e.children.head.toString())
        )
    }

    showInfo(show(superClassSymbol))
    showInfo(show(superClassTree))

    val impl = q"private[this] object ${TermName("impl")} extends ..${superClassTree}"

    // collect every callable method of the super classes (Object's members excluded)
    val methods = superClassSymbol.map(_.info.members
      .filterNot(_.isConstructor)
      .filterNot(e => typeOf[Object].members.exists(_.name == e.name)).map(_.asMethod)).toList

    case class ReplaceTypeParams(from: String, to: String)
    type ClassReplace = List[ReplaceTypeParams]

    // e.g. trait a[A]; class b[B] extends a[B]
    // the trait's type parameter A has to be replaced by the class's B
    val classReplaceList: List[ClassReplace] = superClassTree zip superClassSymbol map {
      case (superClassTree, superClassSymbol) =>
        superClassSymbol.asClass.typeParams.map(_.name) zip superClassTree.children.tail map
          (e => ReplaceTypeParams(e._1.toString, e._2.toString()))
    }

    val out = classReplaceList zip methods map {
      case (classReplace, func) =>

        func map { e => {

          val funcName = e.name

          val funcTypeParams = e.typeParams.map(_.name.toString).map(name => {
            TypeDef(Modifiers(Flag.PARAM), TypeName(name), List(), TypeBoundsTree(EmptyTree, EmptyTree))
          })

          val funcParams = e.paramLists.map(_.map(e => q"${e.name.toTermName}:${
            TypeName(
              classReplace.find(_.from == e.info.toString).map(_.to).getOrElse(e.info.toString)
            )} "))

          val funcResultType = TypeName(
            classReplace.find(_.from == e.returnType.toString).map(_.to).getOrElse(e.returnType.toString)
          )
          q"""
           def ${funcName}[..${funcTypeParams}](...$funcParams):${funcResultType}=
              impl.${funcName}[..${funcTypeParams}](...$funcParams)
            """
        }
        }

    }

    showInfo(show(out))

    q"""
       class ${classDef.name}[..${classDef.tparams}]{
        $impl
        ..${out.flatten}
       }
      """
  }
}
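The `ReplaceTypeParams` step above compares whole type strings (`_.from == e.info.toString`), so it presumably rewrites only parameters whose type is exactly a trait type parameter; a parameter typed `List[MT1]` would not be rewritten. The toy model below (a hypothetical `TypeExpr` ADT, not the reflection API) shows the recursive substitution that such nested types need; with real reflection `Type`s, `Type.substituteTypes` or `Symbol.infoIn` perform it.

```scala
// Toy model of a type expression (hypothetical, for illustration only).
sealed trait TypeExpr
final case class Param(name: String) extends TypeExpr
final case class Applied(ctor: String, args: List[TypeExpr]) extends TypeExpr

object Substitution {
  // Replace type parameters by name, recursing into type arguments,
  // e.g. MT1 -> MCT1 turns List[MT1] into List[MCT1].
  def substitute(t: TypeExpr, mapping: Map[String, String]): TypeExpr = t match {
    case Param(n)            => Param(mapping.getOrElse(n, n))
    case Applied(ctor, args) => Applied(ctor, args.map(substitute(_, mapping)))
  }

  // Render back to Scala-like syntax.
  def render(t: TypeExpr): String = t match {
    case Param(n)            => n
    case Applied(ctor, Nil)  => ctor
    case Applied(ctor, args) => args.map(render).mkString(s"$ctor[", ", ", "]")
  }

  def main(args: Array[String]): Unit = {
    // Built like the answer's list: trait type params zipped with class type args.
    val mapping = (List("MT1") zip List("MCT1")).toMap
    println(render(substitute(Applied("List", List(Param("MT1"))), mapping)))
  }
}
```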

Test:

trait MyTrait[MT1] {

  def x(t1: MT1)(t2: MT1): MT1 = t1

}

trait MyTrait2[MT2] {
  def t(t2: MT2): MT2 = t2
}


@AnnotationWithTrait
class MyClass[MCT1, MCT2] extends MyTrait[MCT1] with MyTrait2[MCT2]

object AnnotationWithTraitUsing extends App {
  assert(new MyClass[Int, String].x(1)(2) == 1)
  assert(new MyClass[Int, String].t("aaa") == "aaa")
}