我想编写一个 Scala 宏,用 Map 中的条目覆盖 case class 的字段值,并附带简单的类型检查:如果原始字段类型与覆盖值的类型兼容,则设置新值,否则保留原始值。
到目前为止,我有以下代码:
import language.experimental.macros
import scala.reflect.macros.Context
/** Macro utilities for overriding case-class fields from a `Map[String, Any]`. */
object ProductUtils {
  /** Intended to return a copy of `entity` in which every case field whose name is a
    * key of `overrides` takes the override value when it matches the field's type,
    * and keeps its original value otherwise. Expanded at compile time by
    * [[withOverridesImpl]].
    */
  def withOverrides[T](entity: T, overrides: Map[String, Any]): T =
    macro withOverridesImpl[T]
  /** Macro implementation: emits `entity.copy(f1 = ..., f2 = ..., ...)` with one named
    * argument per case accessor declared on `T`.
    */
  def withOverridesImpl[T: c.WeakTypeTag](c: Context)
    (entity: c.Expr[T], overrides: c.Expr[Map[String, Any]]): c.Expr[T] = {
    import c.universe._
    // Equivalent to `entity.tree`. NOTE(review): this same tree is inserted once per
    // field selection and once as the `copy` receiver, so a side-effecting entity
    // expression may be evaluated more than once at run time.
    val originalEntityTree = reify(entity.splice).tree
    // The `copy` member of the entity's actual type; checked before it is used below.
    val originalEntityCopy = entity.actualType.member(newTermName("copy"))
    // One (name, `entity.<field>` expression, field type) triple per case accessor
    // declared on T. NOTE(review): the Expr is typed as `T` although it holds a
    // field value of type `ctype`.
    val originalEntity =
      weakTypeOf[T].declarations.collect {
        case m: MethodSymbol if m.isCaseAccessor =>
          (m.name, c.Expr[T](Select(originalEntityTree, m.name)), m.returnType)
      }
    // Named arguments for `copy`: `field = <override value, or original value>`.
    val values =
      originalEntity.map {
        case (name, value, ctype) =>
          AssignOrNamedArg(
            Ident(name),
            {
              // K is meant to stand for the field type. Because K is an abstract type
              // parameter here, the type pattern `newValue : K` is erased, and the
              // compiler reports "abstract type pattern K is unchecked since it is
              // eliminated by erasure".
              def reifyWithType[K: WeakTypeTag] = reify {
                overrides
                  .splice
                  .asInstanceOf[Map[String, Any]] // redundant: already a Map[String, Any]
                  .get(c.literal(name.decoded).splice) match {
                  case Some(newValue : K) => newValue
                  case _ => value.splice
                }
              }
              // Supply the field's type explicitly as K's WeakTypeTag.
              reifyWithType(c.WeakTypeTag(ctype)).tree
            }
          )
      }.toList
    originalEntityCopy match {
      // `copy` resolved to a method: emit `entity.copy(<named args>)`.
      case s: MethodSymbol =>
        c.Expr[T](
          Apply(Select(originalEntityTree, originalEntityCopy), values))
      // Otherwise (presumably NoSymbol when T has no `copy`) fail the expansion.
      case _ => c.abort(c.enclosingPosition, "No eligible copy method!")
    }
  }
}
调用示例如下:
import macros.ProductUtils
// Sample case class passed to ProductUtils.withOverrides in the demo.
// NOTE(review): `filed3` looks like a typo for `field3`; left as-is because renaming a
// case-class field changes its public API (accessor, `copy` and named arguments).
case class Example(field1: String, field2: Int, filed3: String)
/** Manual demo for `ProductUtils.withOverrides`. */
object MacrosTest {
  /** Overrides `field1` with a String (its type matches, so the value should be
    * replaced) and `field2` with a String (its type does not match the Int field,
    * so the original value should be kept), then prints the result.
    */
  def main(args: Array[String]): Unit = {
    val overrides = Map("field1" -> "new value", "field2" -> "wrong type")
    // Intended output (case-class toString does not quote strings):
    // Example(new value,0,)
    println(ProductUtils.withOverrides(Example("", 0, ""), overrides))
  }
}
如您所见,我已经设法取得了原始字段的类型,现在想在 reifyWithType 中针对该类型进行模式匹配。
不幸的是,在当前的实现中,我在编译期间收到警告:
warning: abstract type pattern K is unchecked since it is eliminated by erasure case Some(newValue : K) => newValue
以及在 IntelliJ 中编译时编译器崩溃:
Exception in thread "main" java.lang.NullPointerException
at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseAsInstanceOf$1(Erasure.scala:1032)
at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseNormalApply(Erasure.scala:1083)
at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseApply(Erasure.scala:1187)
at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preErase(Erasure.scala:1193)
at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1268)
at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1018)
at scala.reflect.internal.Trees$class.itransform(Trees.scala:1217)
at scala.reflect.internal.SymbolTable.itransform(SymbolTable.scala:13)
at scala.reflect.internal.SymbolTable.itransform(SymbolTable.scala:13)
at scala.reflect.api.Trees$Transformer.transform(Trees.scala:2897)
at scala.tools.nsc.transform.TypingTransformers$TypingTransformer.transform(TypingTransformers.scala:48)
at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1280)
at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1018)
所以我的问题是:
* 能否在宏中将收到的(编译期)类型与运行时计算出的类型进行比较?
* 或者,有没有更好的方法来完成这个任务?
答案 0(得分:0)
最终,我得到了以下解决方案:
import language.experimental.macros
import scala.reflect.macros.Context
/** Macro utilities for overriding case-class fields from a `Map[String, Any]`. */
object ProductUtils {
  /** Returns a copy of `entity` in which each case field named in `overrides` takes
    * the override value if it passes the runtime compatibility check generated by
    * [[withOverridesImpl]]; fields that are absent from the map, or whose candidate
    * fails the check, keep their original values.
    */
  def withOverrides[T](entity: T, overrides: Map[String, Any]): T =
    macro withOverridesImpl[T]
  /** Macro implementation: emits `entity.copy(f1 = ..., ...)` with one named argument
    * per case accessor declared on `T`. At run time each argument evaluates to the
    * override value when it is judged compatible with the field type, and to the
    * original field value otherwise.
    */
  def withOverridesImpl[T: c.WeakTypeTag](c: Context)(entity: c.Expr[T], overrides: c.Expr[Map[String, Any]]): c.Expr[T] = {
    import c.universe._
    // Equivalent to `entity.tree`. NOTE(review): this tree is inserted once per field
    // selection and once as the `copy` receiver, so a side-effecting entity
    // expression may be evaluated more than once at run time.
    val originalEntityTree = reify(entity.splice).tree
    // The `copy` member of the entity's actual type; checked before it is used below.
    val originalEntityCopy = entity.actualType.member(newTermName("copy"))
    // (name, `entity.<field>` expression, field type) for every case accessor of T.
    // resetAllAttrs strips type attributes from the entity tree before it is reused in
    // each field selection (presumably so one typed tree is not shared between
    // several places in the expansion).
    val originalEntity =
      weakTypeOf[T].declarations.collect {
        case m: MethodSymbol if m.isCaseAccessor =>
          (m.name, c.Expr[T](Select(c.resetAllAttrs(originalEntityTree), m.name)), m.returnType)
      }
    val values =
      originalEntity.map {
        case (name, value, ctype) =>
          AssignOrNamedArg(
            Ident(name),
            {
              // Tree evaluating to the field type's runtime (erased) Class; for
              // primitive fields such as Int this is the primitive class.
              val ruClass = c.reifyRuntimeClass(ctype)
              // Tree evaluating to the field type as a runtime-universe Type
              // (`mtree` selects its `.tpe`), used for exact type comparison.
              val mtag = c.reifyType(treeBuild.mkRuntimeUniverseRef, Select(treeBuild.mkRuntimeUniverseRef, newTermName("rootMirror")), ctype)
              val mtree = Select(mtag, newTermName("tpe"))
              // K stands for the field type (its tag is supplied explicitly below).
              def reifyWithType[K: c.WeakTypeTag] = reify {
                // Returns the candidate cast to K when it is judged compatible with the
                // field type, otherwise None. A is the static type the compiler infers
                // for the looked-up value. NOTE(review): confirm whether that is the
                // override map's value type at the call site or always Any.
                def tryNewValue[A: scala.reflect.runtime.universe.TypeTag](candidate: Option[A]): Option[K] =
                  if (candidate.isEmpty) {
                    None
                  } else {
                    val cc = c.Expr[Class[_]](ruClass).splice
                    val candidateValue = candidate.get
                    val candidateType = scala.reflect.runtime.universe.typeOf[A]
                    val expectedType = c.Expr[scala.reflect.runtime.universe.Type](mtree).splice
                    // Primitive field: the candidate arrives boxed, so accept it only if
                    // the field's primitive class matches the boxed value's wrapper.
                    val ok = (cc.isPrimitive, candidateValue) match {
                      case (true, _: java.lang.Integer) => cc == java.lang.Integer.TYPE
                      case (true, _: java.lang.Long) => cc == java.lang.Long.TYPE
                      case (true, _: java.lang.Double) => cc == java.lang.Double.TYPE
                      case (true, _: java.lang.Character) => cc == java.lang.Character.TYPE
                      case (true, _: java.lang.Float) => cc == java.lang.Float.TYPE
                      case (true, _: java.lang.Byte) => cc == java.lang.Byte.TYPE
                      case (true, _: java.lang.Short) => cc == java.lang.Short.TYPE
                      case (true, _: java.lang.Boolean) => cc == java.lang.Boolean.TYPE
                      case (true, _: Unit) => cc == java.lang.Void.TYPE
                      // Reference field, or primitive field with a non-primitive candidate:
                      // require exact type equality when the candidate's static type is
                      // precise. If it is Any or has an Any type argument, fall back to
                      // Class.isInstance, which only checks the erased class. With a static
                      // type such as List[Any], a List[Date] is therefore accepted for a
                      // List[Int] field.
                      // NOTE(review): the cast below assumes candidateType is a TypeRef; a
                      // refined or existential type would presumably throw
                      // ClassCastException — consider matching on TypeRef(_, _, args).
                      case _ =>
                        val args = candidateType.asInstanceOf[scala.reflect.runtime.universe.TypeRefApi].args
                        if (!args.contains(scala.reflect.runtime.universe.typeOf[Any])
                          && !(candidateType =:= scala.reflect.runtime.universe.typeOf[Any]))
                          candidateType =:= expectedType
                        else cc.isInstance(candidateValue)
                    }
                    // Unchecked cast: only as safe as the checks above.
                    if (ok)
                      Some(candidateValue.asInstanceOf[K])
                    else None
                  }
                // Look the field name up in the overrides. Keep the original field value
                // when the key is absent or the candidate is rejected.
                // NOTE(review): the overrides expression is spliced once per field and
                // may therefore be evaluated once per field.
                tryNewValue(overrides.splice.get(c.literal(name.decoded).splice)).getOrElse(value.splice)
              }
              reifyWithType(c.WeakTypeTag(ctype)).tree
            }
          )
      }.toList
    originalEntityCopy match {
      // `copy` resolved to a method: emit `entity.copy(<named args>)`.
      case s: MethodSymbol =>
        c.Expr[T](
          Apply(Select(originalEntityTree, originalEntityCopy), values))
      // Otherwise (presumably NoSymbol when T has no `copy`) fail the expansion.
      case _ => c.abort(c.enclosingPosition, "No eligible copy method!")
    }
  }
}
该方案满足了最初的需求:
/** ScalaTest suite for `ProductUtils.withOverrides`. The case classes declared inside
  * the suite serve as fixtures.
  * NOTE(review): `FunSuite` and `Date` (presumably java.util.Date) come from imports
  * not shown here.
  */
class ProductUtilsTest extends FunSuite {
  case class A(a: String, b: String)
  case class B(a: String, b: Int)
  case class C(a: List[Int], b: List[String])
  case class D(a: Map[Int, String], b: Double)
  case class E(a: A, b: B)
  // Both overrides are Strings for String fields, so both are applied.
  test("simple overrides works"){
    val overrides = Map("a" -> "A", "b" -> "B")
    assert(ProductUtils.withOverrides(A("", ""), overrides) === A("A", "B"))
  }
  // Mixed value types: "A" fits String field `a`, and 1 fits Int field `b`.
  test("simple overrides works 1"){
    val overrides = Map("a" -> "A", "b" -> 1)
    assert(ProductUtils.withOverrides(B("", 0), overrides) === B("A", 1))
  }
  // 0 does not fit String field `a`, and a List does not fit Int field `b`,
  // so B keeps its original values.
  test("do not override if types do not match"){
    val overrides = Map("a" -> 0, "b" -> List("B"))
    assert(ProductUtils.withOverrides(B("", 0), overrides) === B("", 0))
  }
  // Generic collection fields receiving lists of the right element type are overridden.
  test("complex types also works"){
    val overrides = Map("a" -> List(1), "b" -> List("A"))
    assert(ProductUtils.withOverrides(C(List(0), List("")), overrides) === C(List(1), List("A")))
  }
  // `a` receives a List but is a Map[Int, String] field, so it is rejected;
  // `b` receives a Double for a Double field, so it is applied.
  test("complex types also works 1"){
    val overrides = Map("a" -> List(new Date()), "b" -> 2.0d)
    assert(ProductUtils.withOverrides(D(Map(), 1.0), overrides) === D(Map(), 2.0))
  }
  // A nested case-class value replaces field `a`; `b` receives a Double, which does not
  // fit field `b: B`, so it keeps its original value.
  test("complex types also works 2"){
    val overrides = Map("a" -> A("AA", "BB"), "b" -> 2.0d)
    assert(ProductUtils.withOverrides(E(A("", ""), B("", 0)), overrides) === E(A("AA", "BB"), B("", 0)))
  }
  // NOTE(review): no test covers a value whose erased class matches the field but whose
  // type arguments differ (e.g. List[Date] for a List[Int] field), which can be accepted.
}
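遗憾的是,由于 Java/Scala 的类型擦除,在替换为新值之前很难严格保证类型相等,因此仍可能出现下面的情况。当 Map 的值类型被推断为 List[Any] 时,类型参数中含有 Any,代码会退回到只检查擦除后类的 Class.isInstance。于是 List[Date] 被当作 List[Int] 接受,直到真正使用该值时才抛出 ClassCastException: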
scala> case class C(a: List[Int], b: List[String])
defined class C
scala> val overrides = Map("a" -> List(new Date()), "b" -> List(1.0))
overrides: scala.collection.immutable.Map[String,List[Any]] = Map(a -> List(Mon Aug 26 15:52:27 CEST 2013), b -> List(1.0))
scala> ProductUtils.withOverrides(C(List(0), List("")), overrides)
res0: C = C(List(Mon Aug 26 15:52:27 CEST 2013),List(1.0))
scala> res0.a.head + 1
java.lang.ClassCastException: java.util.Date cannot be cast to java.lang.Integer
at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:106)
at .<init>(<console>:14)
at .<clinit>(<console>)
at .<init>(<console>:7)
at .<clinit>(<console>)
at $print(<console>)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:606)
at scala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:734)
at scala.tools.nsc.interpreter.IMain$Request.loadAndRun(IMain.scala:983)
at scala.tools.nsc.interpreter.IMain.loadAndRunReq$1(IMain.scala:573)
at scala.tools.nsc.interpreter.IMain.interpret(IMain.scala:604)