如何减少泛型方法中不需要的类型参数?

时间:2016-10-28 04:46:21

标签: scala generics

我想用一些灵活的方法实现一些通用的数学函数。 例如一个名为meandot的函数,声明为

object Calc {
  def meandot[..](xs: Array[Left], ys: Array[Right])(implicit ..): Result
}

其中meandot(xs, ys) = sum(x*y for x, y in zip(xs, ys)) / length

当我调用meandot没有专门的类型参数时,它应该返回一个默认类型的值。 e.g。

scala> Calc.meandot(Array(1, 2), Array(1, 1))
res0: Int = 1

如果我使用专门的类型参数调用meandot,它可以返回正确的值。

scala> Calc.meandot[Int, Int, Double](Array(1, 2), Array(1, 1))
res1: Double = 1.5

但是,上面的前两个类型参数是多余的。我需要专门的唯一类型是返回类型。我想将其简化为

scala> Calc.meandot2(Array(1, 2), Array(1, 1))
res2: Int = 1

scala> Calc.meandot2[Double](Array(1, 2), Array(1, 1))
res3: Double = 1.5

我找到了一种方法来实现它,如下面的代码,它使用代理类MeanDotImp但似乎并不那么优雅。所以我想知道是否有更好的解决方案来减少泛型方法中不需要的类型参数?

trait Times[L, R, N] {
  def times(x: L, y: R): N
}

trait Num[N] {
  def zero: N = fromInt(0)
  def one:  N = fromInt(1)
  def fromInt(i: Int): N
  def plus(x: N, y: N): N
  def div(x: N, y: N): N
}

abstract class LowTimesImplicits {
  implicit val IID: Times[Int, Int, Double] = new Times[Int, Int, Double] {
    def times(x: Int, y: Int): Double = x * y
  }
}

object Times extends LowTimesImplicits {
  implicit val III: Times[Int, Int, Int] = new Times[Int, Int, Int] {
    def times(x: Int, y: Int): Int = x * y
  }
}

object Num {
  implicit val INT: Num[Int] = new Num[Int] {
    def fromInt(i: Int): Int = i
    def plus(x: Int, y: Int): Int = x + y
    def div(x: Int, y: Int): Int = x / y
  }

  implicit val DOU: Num[Double] = new Num[Double] {
    def fromInt(i: Int): Double = i
    def plus(x: Double, y: Double): Double = x + y
    def div(x: Double, y: Double): Double = x / y
  }
}

object Calc {
  def meandot[L, R, N](xs: Array[L], ys: Array[R])
             (implicit t: Times[L, R, N], n: Num[N]): N = {
    val total = (xs, ys).zipped.foldLeft(n.zero){
           case(r, (x, y)) => n.plus(r, t.times(x, y))
        }
    n.div(total, n.fromInt(xs.length))
  }

  implicit class MeanDotImp[L, R](val marker: Calc.type) {
    def meandot2[N](xs: Array[L], ys: Array[R])
                (implicit t: Times[L, R, N], n: Num[N]): N = {
      val total = (xs, ys).zipped.foldLeft(n.zero){
            case(r, (x, y)) => n.plus(r, t.times(x, y))
          }
      n.div(total, n.fromInt(xs.length))
    }
  }
}

1 个答案:

答案 0 :(得分:2)

替代解决方案与您的解决方案类似,但更简单一点:它首先修复您希望能够设置的类型参数,然后推断其他两个。为此,我们可以使用class meandot[N] { def apply[L, R](xs: Array[L], ys: Array[R]) (implicit t: Times[L, R, N], n: Num[N]): N = ??? // your implementation } 方法声明一个类:

new meandot

现在,为了避免编写object Calc { def meandot[N]: meandot[N] = new meandot[N] } ,我们可以定义一个只实例化这个类的方法:

scala> Calc.meandot(Array(1,2,3), Array(4,5,6))
res0: Int = 10

scala> Calc.meandot[Double](Array(1,2,3), Array(4,5,6))
res1: Double = 10.666666666666666

这种方法的优雅是有争议的,但它很简单,并不涉及隐含。这是一个用法演示:

A = A + D;