如何使用两种参数类型在Scala中实现泛型函数?

时间:2015-06-11 02:29:29

标签: scala generics

我想在Scala中实现一个函数来计算两个数字序列的点积,如下所示

val x = Seq(1,2,3.0)
val y = Seq(4,5,6)
val z = (for (a <- x; b <- y) yield a*b).sum
scala> z  : Double = 90.0

val x = Seq(1,2,3)
val y = Seq(4,5,6)
val z = (for (a <- x; b <- y) yield a*b).sum
scala> z  : Int = 90

请注意,如果两个序列的类型不同,则结果为Double。如果两个序列具有相同的类型(例如Int),则结果为Int。

我提出了两种选择,但都不符合上面定义的要求。

备选方案#1:

def dotProduct[T: Numeric](x: Seq[T], y: Seq[T]): T = (for (a <- x; b <- y) yield implicitly[Numeric[T]].times(a, b)).sum

这将返回与输入相同类型的结果,但不能采用两种不同的类型。

备选方案#2:

def dotProduct[A, B](x: Seq[A], y: Seq[B])(implicit nx: Numeric[A], ny: Numeric[B]) = (for (a <- x; b <- y) yield nx.toDouble(a)*ny.toDouble(b)).sum

这适用于所有数字序列。但是,它总是返回一个Double,即使这两个序列是Int。

类型

非常感谢任何建议。

P.S。我上面实现的功能不是“点积”,而只是两个序列的乘积之和。谢谢Daniel指出来。

备选方案#3(略好于备选方案#1和#2):

def sumProduct[T, A <% T, B <% T](x: Seq[A], y: Seq[B])(implicit num: Numeric[T]) = (for (a <- x; b <- y) yield num.times(a,b)).sum

sumProduct(Seq(1,2,3), Seq(4,5,6))  //> res0: Int = 90
sumProduct(Seq(1,2,3.0), Seq(4,5,6))  //> res1: Double = 90.0
sumProduct(Seq(1,2,3), Seq(4,5,6.0))  // Fails!!!

不幸的是,Scala 2.10中将弃用View Bound功能(例如“&lt;%”)。

1 个答案:

答案 0 :(得分:1)

您可以创建一个代表促销规则的类型类:

trait NumericPromotion[A, B, C] {
  def promote(a: A, b: B): (C, C)
}

implicit object IntDoublePromotion extends NumericPromotion[Int, Double, Double] {
  def promote(a: Int, b: Double): (Double, Double) = (a.toDouble, b)
}

def dotProduct[A, B, C]
              (x: Seq[A], y: Seq[B])
              (implicit numEv: Numeric[C], promEv: NumericPromotion[A, B, C])
              : C = {
  val foo = for {
    a <- x
    b <- y
  } yield {
    val (pa, pb) = promEv.promote(a, b)
    numEv.times(pa, pb)
  }

  foo.sum
}

dotProduct[Int, Double, Double](Seq(1, 2, 3), Seq(1.0, 2.0, 3.0))

我的类型类fu不足以消除对dotProduct的调用中的显式类型参数,也无法弄清楚如何避免方法中的val foo;内联foo导致编译错误。我认为没有真正内化隐式解决规则。也许其他人可以让你更进一步。

值得一提的是,这是方向性的;你无法计算dotProduct(Seq(1.0, 2.0, 3.0), Seq(1, 2, 3))。但这很容易解决:

implicit def flipNumericPromotion[A, B, C]
                                 (implicit promEv: NumericPromotion[B, A, C])
                                 : NumericPromotion[A, B, C] = 
  new NumericPromotion[A, B, C] {
    override def promote(a: A, b: B): (C, C) = promEv.promote(b, a)
  }

还值得一提的是,您的代码并未计算点积。 [1, 2, 3][4, 5, 6]的点积为4 + 10 + 18 = 32