Scala中的多态点积和匿名函数速记

时间:2011-11-28 16:40:18

标签: scala types polymorphism anonymous-function

我想通过以下方式在Scala中实现“矩阵点产品”:

type Real = Double
type Row = Array[Real]
type Matrix = Array[Row]

def dot[T](f: (T,T) => Real)(as: Iterable[T], bs: Iterable[T]): Real =
  (for ((a, b) <- as zip bs) yield f(a, b)) sum

def rowDot(r1: Row, r2: Row) = dot(_*_)(r1, r2)
def matDot(m1: Matrix, m2: Matrix) = dot(rowDot)(m1, m2)

但是,rowDot的定义不起作用。 Scala需要匿名函数(_*_)的显式类型注释,所以我必须编写

def rowDot(r1: Row, r2: Row) = dot((x:Real, y: Real) => x*y)(r1, r2)

def rowDot = dot((x:Real, y: Real) => x*y) _

有没有办法改变dot的定义,以便可以使用简写(_*_)

编辑:另一个困惑:matDot在某些情况下也会出现类型错误。它使用Arrays of Arrays失败,但没有使用Arrays of Lrays

scala> matDot(Array(Array(1.0,2.0)), Array(Array(1.0,2.0,3.0)))
<console>:27: error: type mismatch;
 found   : Array[Array[Double]]
 required: Iterable[Iterable[Real]]
              matDot(Array(Array(1.0,2.0)), Array(Array(1.0,2.0,3.0)))
                          ^

scala> matDot(List(Array(1.0,2.0)), List(Array(1.0,2.0,3.0)))
res135: Real = 5.0

有什么区别?

2 个答案:

答案 0 :(得分:3)

是 - 如果您切换参数列表。当函数参数在最后一个参数列表中单独使用时,函数参数的类型推断更有效:

def dot[T](as: Iterable[T], bs: Iterable[T])(f: (T,T) => Real): Real =
  (for ((a, b) <- as zip bs) yield f(a, b)) sum

def rowDot(r1: Row, r2: Row) = dot(r1, r2)(_*_)

答案 1 :(得分:3)

明确指定dot[Real]也应该有用。

def rowDot(r1: Row, r2: Row) = dot[Real](_*_)(r1, r2)

修改

回复您的修改:我认为问题是当您拥有Array时,不会递归地应用从WrappedArrayArray[Array]的隐式转换。

Array[Int]不是Iterable[Int];通常,当您将其分配给Iterable时,Array[Int]会隐式转换为WrappedArray[Int](其中WrappedArray 是Iterable [Int])。当您使用List[Array[Int]](隐式获得List[WrappedArray[Int]])时会发生这种情况。

但是,正如我所说,隐式转换不是递归应用的,因此Array[Array[Int]]不会隐式转换为WrappedArray[WrappedArray[Int]]

这是一个演示问题的REPL会话:

List [Array [Int]]可以分配给Iterable [Iterable [Int]](注意Array转换为WrappedArray)

scala> val i : Iterable[Iterable[Int]] = List(Array(1,2), Array(1,2,3))
i: Iterable[Iterable[Int]] = List(WrappedArray(1, 2), WrappedArray(1, 2, 3))

Array [Array [Int]]不会自动运行(如您所发现的)

scala> val j : Iterable[Iterable[Int]] = Array(Array(1,2), Array(1,2,3))
<console>:9: error: type mismatch;
 found   : Array[Array[Int]]
 required: Iterable[Iterable[Int]]
       val j : Iterable[Iterable[Int]] = Array(Array(1,2), Array(1,2,3))
                                              ^

但是,通过一些手持(手动将内部数组转换为WrappedArrays),一切都会再次运作:

    scala> import scala.collection.mutable.WrappedArray
    import scala.collection.mutable.WrappedArray

    scala> val k : Iterable[Iterable[Int]] = Array(WrappedArray.make(Array(1,2)),
 WrappedArray.make(Array(1,2,3)))
    k: Iterable[Iterable[Int]] = WrappedArray(WrappedArray(1, 2), WrappedArray(1, 2,
     3))