矢量维度的编译时检查

时间:2011-02-12 14:19:39

标签: scala types

我在scala中实现了一些轻量级数学向量。我想在编译时使用类型系统来检查向量兼容性。例如,尝试将维度2的向量添加到另一个维度为3的向量时应该导致编译错误。

到目前为止,我将维度定义为案例类:

sealed trait Dim
case class One() extends Dim
case class Two() extends Dim
case class Three() extends Dim
case class Four() extends Dim
case class Five() extends Dim

这是矢量定义:

class Vec[D <: Dim](val values: Vector[Double]) {

  def apply(i: Int) = values(i)

  def *(k: Double) = new Vec[D]( values.map(_*k) )

  def +(that: Vec[D]) = {
    val newValues = ( values zip that.values ) map { 
      pair => pair._1 + pair._2
    }
    new Vec[D](newValues)
  }

  override lazy val toString = "Vec(" + values.mkString(", ") + ")"

}

此解决方案效果很好,但我有两个问题:

  • 如何添加返回维度的dimension():Int方法(即Vec[Three]为3)?

  • 如何在不事先声明所有必需案例类的情况下处理更高维度?

PS:我知道有很好的数学矢量库,我只是想提高我对scala的理解。

3 个答案:

答案 0 :(得分:3)

我的建议:

答案 1 :(得分:1)

我建议这样的事情:

sealed abstract class Dim(val dimension:Int)

object Dim {
  class One extends Dim(1)
  class Two extends Dim(2)
  class Three extends Dim(3)

  implicit object One extends One
  implicit object Two extends Two
  implicit object Three extends Three
}

case class Vec[D <: Dim](values: Vector[Double])(implicit dim:D) {

  require(values.size == dim.dimension)

  def apply(i: Int) = values(i)

  def *(k: Double) = Vec[D]( values.map(_*k) )

  def +(that: Vec[D]) = Vec[D](
     ( values zip that.values ) map {
      pair => pair._1 + pair._2
  })

  override lazy val toString = values.mkString("Vec(",", ",")")
}

当然,你只能通过运行时检查向量长度,但正如其他人已经指出的那样,你需要教会数字或其他类型级编程技术来实现编译时间检查。< / p>

  import Dim._
  val a = Vec[Two](Vector(1.0,2.0))
  val b = Vec[Two](Vector(1.0,3.0))
  println(a + b)
  //--> Vec(2.0, 5.0) 

  val c = Vec[Three](Vector(1.0,3.0)) 
  //--> Exception in thread "main" java.lang.ExceptionInInitializerError
  //-->        at scalatest.vecTest.main(vecTest.scala)
  //--> Caused by: java.lang.IllegalArgumentException: requirement failed

答案 2 :(得分:0)

如果您不希望沿着Peano路线走下去,您可以随时使用Vec构建D,然后使用该实例通过{{1}确定尺寸伴侣对象。例如:

Dim

我认为在选择时,您应该使用案例对象而不是案例类:

object Dim {
  def dimensionOf(d : Dim) = d match {
    case One => 1
    case Two => 2
    case Three => 3
  }
}
sealed trait Dim

然后在你的矢量上,你可能必须实际存储Dim:

case object One extends Dim
case object Two extends Dim