正确使用包装类的继承

时间:2015-09-30 16:01:48

标签: scala

我正在尝试使用ejml库为矩阵运算编写一个scala-wrapper。基本上我只使用SimpleMatrix。但是,我想要矩阵和向量的不同类,例如,只能反转矩阵或显式声明函数返回向量,而不是矩阵。目前,我无法返回具体类而不是特性。

我从一个特质开始,MLMatrixLike:

trait MLMatrixLike {
  def data: SimpleMatrix
  protected def internalMult(implicit that: MLMatrixLike): SimpleMatrix = {
    data.mult(that.data)
  }
  def *(implicit that: MLMatrixLike): MLVector = MLVector(internalMult)
}

我的矩阵类和我的vector类都在扩展特性:

case class MLMatrix(data: SimpleMatrix) extends MLMatrixLike {

  def this(rawData: Array[Array[Double]]) = this(new SimpleMatrix(rawData))

  def apply(row: Int, col:Int): Double = data.get(row, col)

  def transpose(): MLMatrix = MLMatrix(data.transpose())

  def invert(): MLMatrix = MLMatrix(data.invert())

  def *(implicit that: MLMatrix): MLMatrix = MLMatrix(internalMult)

  def *(that: Double): MLMatrix = MLMatrix(data.scale(that))

  def -(that: MLMatrix): MLMatrix = MLMatrix(data.minus(that.data))
}

object MLMatrix {
  def apply(rawData: Array[Array[Double]]) = new MLMatrix(rawData)
}

case class MLVector(data: SimpleMatrix) extends MLMatrixLike {

  def this(rawData: Array[Double]) = {
    this(new SimpleMatrix(Array(rawData)).transpose())
  }

  def apply(index: Int): Double = data.get(index)

  def transpose(): MLVector = MLVector(data.transpose())

  def -(that: MLVector): MLVector = MLVector(data.minus(that.data))
}

object MLVector {
  def apply(rawData: Array[Double]) = new MLVector(rawData)
}

在我看来,这种设置并不是很好。我想只定义乘法(*)一次,因为SimpleMatrix调用总是相同的,我可以从参数的类型推断出#34;"返回类型应该是矩阵还是向量。因此,我想在MLMatrixLike中沿着这个(不工作)函数的行定义一个函数:

def *[T <: MLMatrixLike](that :T): T = {
  new T(data.mult(that.data))
}

当然,这不起作用,因为没有这样的构造函数T,但目前我没有看到,我怎么能得到类似工作的东西。返回MLMatrixLike是不正确的,因为我在编译期间无法检查是否返回了正确的类型。

类似的问题适用于转置和减号 - 这里返回类型始终是自己的类。

非常感谢!

1 个答案:

答案 0 :(得分:1)

我不确定在其他两个类中包装SimpleMatrix的好处是什么。但是,您可以通过使MLMatrixLike通用其自身类型并定义抽象构造函数来解决重复问题。

trait MLMatrixLike[Self <: MLMatrixLike[Self]] {
  this: Self =>
  def data: SimpleMatrix

  def createNew(data: SimpleMatrix): Self

  def *[T <: MLMatrixLike[T]](that: T): T = that.createNew(data.mult(that.data))

  def *(that: Double): Self = createNew(data.scale(that))

  def -(that: Self): Self = createNew(data.minus(that.data))

  def transpose: Self = createNew(data.transpose())
}

case class MLMatrix(data: SimpleMatrix) extends MLMatrixLike[MLMatrix] {
  this: MLMatrix =>

  def this(rawData: Array[Array[Double]]) = this(new SimpleMatrix(rawData))

  override def createNew(data: SimpleMatrix): MLMatrix = MLMatrix(data)

  def apply(row: Int, col: Int): Double = data.get(row, col)

  def invert(): MLMatrix = MLMatrix(data.invert())

}

object MLMatrix {
  def apply(rawData: Array[Array[Double]]) = new MLMatrix(rawData)
}

case class MLVector(data: SimpleMatrix) extends MLMatrixLike[MLVector] {
  this: MLVector =>

  def this(rawData: Array[Double]) = {
    this(new SimpleMatrix(Array(rawData)).transpose())
  }

  override def createNew(data: SimpleMatrix): MLVector = MLVector(data)

  def apply(index: Int): Double = data.get(index)

}

object MLVector {
  def apply(rawData: Array[Double]) = new MLVector(rawData)
}

顺便说一下,请注意,列向量乘以行向量是一个矩阵,因此乘法的签名可能不会返回that的类型。但是,基于静态信息(您需要知道两个参数的维度),您无法判断乘法是返回向量还是矩阵,因此您也可以返回MLMatrixLike