如何避免类型参数擦除

时间:2018-01-09 13:15:56

标签: scala types pattern-matching

我的代码使用Scala和Breeze线性代数库。我有DenseVector[Double]DenseVector[Int]等类型的对象...其中DenseVector是一个类似数组的容器,具有专用的数值计算方法。我有时需要对包含的类型使用模式匹配。类型擦除迫使我引入一个特征和"包装"案例类:

sealed trait DenseVectorRoot
case class DenseVectorReal(val data: DenseVector[Real]) extends DenseVectorRoot
case class DenseVectorMatrixReal(val data: DenseVector[DenseMatrix[Real]]) extends DenseVectorRoot

(其中Real只是Double的别名)。

模式匹配看起来像:

def print(data: DenseVectorRoot) =
  data match {
    case DenseVectorMatrixReal(_) => println("Contains real matrices")
    case DenseVectorReal(_) => println("Contains real scalars")
  }

我想摆脱DenseVectorRoot特质。我试过这个:

def print2(data: DenseVector[_ <: Any]) =
  data match {
    case _: DenseVector[Double] => println("Contains real matrices")
    case _: DenseVector[Int] => println("Contains real scalars")
  }

但类型参数会被删除。

我应该如何使用ClassTags修改print2以便模式匹配有效?例如,通过以下代码打印正确的输出:

val v0 = DenseVector(1.2, 1.5, 1.6)
val v1 = DenseVector(3, 4, 5)

val a = Array(v0, v1)
a.map(print2)

修改

我需要管理具有不同容器的Array的主要原因是我的代码需要管理各种类型的数据(例如,解析输入对于DenseVector[Real]DenseVector[Matrix[Real]])。我目前的设计是将所有内容存储在Array[DenseVectorRoot]中,然后使用.map()等高阶函数处理数据。在元素到元素的基础上,每个函数都将进行模式匹配,以了解数据是DenseVectorReal还是DenseVectorMatrixReal,并采取相应的行动。

这可能不是解决我的问题的最佳设计,但我不知道在编译时用户提供什么类型的数据。我很乐意知道更好的设计!

2 个答案:

答案 0 :(得分:1)

使用类型类更好地完成这类事情:

  trait DenseContent[T] {
    def compute(v: DenseVector[T]): String
  }
  object DenseContent {
    implicit object _Real extends DenseContent[Real] {
      def compute(v: DenseVector[Real]) = "real"
    }
    implicit object _Int extends DenseContent[Int] {
      def compute(v: DenseVector[Int]) = "int"
    }
    // etc ...
  }

  def print2[T : DenseContent](data: DenseVector[T]) = println(
     implicitly[DenseContent[T]].compute(data)
  )

答案 1 :(得分:1)

使用TypeTag

您可以请求编译器推断参数类型,并为您生成TypeTag。 然后,您可以使用TypeTag检查某种类型,也可以print进行调试。

示例代码

import scala.reflect.runtime.universe._

def printType[A: TypeTag](a: List[A]): Unit = 
  println(if(typeTag[A] == typeTag[Double]) "Double" else "other")

printType(List(1.0))
printType(List(1))

输出

>Double
>other