关于简单数学数组运算的新手Scala问题

时间:2010-09-07 05:31:15

标签: scala

Newbie Scala问题:

我想在Scala中执行此[Java代码]:

public static double[] abs(double[] r, double[] im) {
  double t[] = new double[r.length];
  for (int i = 0; i < t.length; ++i) {
    t[i] = Math.sqrt(r[i] * r[i] + im[i] * im[i]);
  }
  return t;
}  

并且也使它成为通用的(因为Scala有效地执行了我读过的通用基元)。仅依靠核心语言(没有库对象/类,方法等),如何做到这一点?说实话,我根本不知道怎么做,所以我猜这只是一个纯粹的奖励点问题。

我遇到了很多问题,试图做这个简单的事情,我现在放弃了Scala。希望一旦我看到斯卡拉的方式,我就会有一个“啊哈”的时刻。

更新: 与其他人讨论这个问题,这是我迄今为止找到的最佳答案。

def abs[T](r: Iterable[T], im: Iterable[T])(implicit n: Numeric[T]) = {
   import n.mkNumericOps                                                   
   r zip(im) map(t => math.sqrt((t._1 * t._1 + t._2 * t._2).toDouble))          
}

4 个答案:

答案 0 :(得分:13)

在scala中执行泛型/高性能原语实际上涉及scala用于避免装箱/拆箱的两种相关机制(例如,将int包裹在java.lang.Integer中,反之亦然) :

  • @specialize类型注释
  • Manifest与数组
  • 一起使用

specialize是一个注释,告诉Java编译器创建“原始”版本的代码(类似于C ++模板,所以我被告知)。查看Tuple2(专业)与List(不是)相比的类型声明。它是在 2.8 中添加的,意味着,例如像CC[Int].map(f : Int => Int)这样的代码执行时没有装箱任何int(假设CC是专门的,当然! )。

Manifest是在scala中执行 reified 类型的一种方式(受JVM的类型擦除限制)。当您希望在某种类型T上对某个方法进行通用化,然后在该方法中创建T(即T[])数组时,此功能尤其有用。在Java中,这是不可能的,因为new T[]是非法的。在scala中,这可以使用Manifest。特别是,在这种情况下,它允许我们构建一个原始 T阵列,如double[]int[]。 (这很棒,万一你想知道)

从性能角度来看,拳击是如此重要,因为它会产生垃圾,除非你的所有int都是&lt; 127.显然,它还增加了一个间接级别的额外处理步骤/方法调用等。但是考虑到你可能不会发出声音,除非你绝对肯定你确实这样做(即大多数代码没有需要这样的微观优化)


所以,回到问题:为了在没有装箱/拆箱的情况下执行此操作,您必须使用ArrayList尚未专业化,并且无论如何都会更加对象,甚至如果是的话!)。一对集合上的zipped函数将返回Tuple2 s的集合(不需要装箱,因为此专用的)。

为了通常这样做(即跨越各种数字类型),您必须要求在通用参数上绑定上下文Numeric并且可以找到Manifest(创建数组时需要) )。所以我开始沿着......

开始
def abs[T : Numeric : Manifest](rs : Array[T], ims : Array[T]) : Array[T] = {
    import math._
    val num = implicitly[Numeric[T]]
    (rs, ims).zipped.map { (r, i) => sqrt(num.plus(num.times(r,r), num.times(i,i))) }
    //                               ^^^^ no SQRT function for Numeric
}

...但不能正常工作。原因是“通用”Numeric值没有像sqrt - &gt;这样的操作。所以你只能在知道你有一个Double时这样做。例如:

scala> def almostAbs[T : Manifest : Numeric](rs : Array[T], ims : Array[T]) : Array[T] = {
 | import math._
 | val num = implicitly[Numeric[T]]
 | (rs, ims).zipped.map { (r, i) => num.plus(num.times(r,r), num.times(i,i)) }
 | }
almostAbs: [T](rs: Array[T],ims: Array[T])(implicit evidence$1: Manifest[T],implicit     evidence$2: Numeric[T])Array[T]

很好 - 现在看到这种纯粹的通用方法做了一些事情!

scala> val rs = Array(1.2, 3.4, 5.6); val is = Array(6.5, 4.3, 2.1)
rs: Array[Double] = Array(1.2, 3.4, 5.6)
is: Array[Double] = Array(6.5, 4.3, 2.1)

scala> almostAbs(rs, is)
res0: Array[Double] = Array(43.69, 30.049999999999997, 35.769999999999996)

现在我们可以sqrt结果,因为我们有Array[Double]

scala> res0.map(math.sqrt(_))
res1: Array[Double] = Array(6.609841147864296, 5.481788029466298, 5.980802621722272)

并证明即使使用其他Numeric类型也可以使用

scala> import math._
import math._
scala> val rs = Array(BigDecimal(1.2), BigDecimal(3.4), BigDecimal(5.6)); val is =     Array(BigDecimal(6.5), BigDecimal(4.3), BigDecimal(2.1))
rs: Array[scala.math.BigDecimal] = Array(1.2, 3.4, 5.6)
is: Array[scala.math.BigDecimal] = Array(6.5, 4.3, 2.1)

scala> almostAbs(rs, is)
res6: Array[scala.math.BigDecimal] = Array(43.69, 30.05, 35.77)

scala> res6.map(d => math.sqrt(d.toDouble))
res7: Array[Double] = Array(6.609841147864296, 5.481788029466299, 5.9808026217222725)

答案 1 :(得分:11)

使用zipmap

scala> val reals = List(1.0, 2.0, 3.0)
reals: List[Double] = List(1.0, 2.0, 3.0)

scala> val imags = List(1.5, 2.5, 3.5)
imags: List[Double] = List(1.5, 2.5, 3.5)

scala> reals zip imags
res0: List[(Double, Double)] = List((1.0,1.5), (2.0,2.5), (3.0,3.5))

scala> (reals zip imags).map {z => math.sqrt(z._1*z._1 + z._2*z._2)}
res2: List[Double] = List(1.8027756377319946, 3.2015621187164243, 4.6097722286464435)

scala> def abs(reals: List[Double], imags: List[Double]): List[Double] =
     | (reals zip imags).map {z => math.sqrt(z._1*z._1 + z._2*z._2)}
abs: (reals: List[Double],imags: List[Double])List[Double]

scala> abs(reals, imags)
res3: List[Double] = List(1.8027756377319946, 3.2015621187164243, 4.6097722286464435)

<强>更新

最好使用zipped,因为它可以避免创建临时集合:

scala> def abs(reals: List[Double], imags: List[Double]): List[Double] =
     | (reals, imags).zipped.map {(x, y) => math.sqrt(x*x + y*y)}
abs: (reals: List[Double],imags: List[Double])List[Double]

scala> abs(reals, imags)
res7: List[Double] = List(1.8027756377319946, 3.2015621187164243, 4.6097722286464435)

答案 2 :(得分:4)

Java中没有一种简单的方法可以创建通用的数字计算代码;从牛轭的答案中你可以看到图书馆不存在。集合也被设计为采用任意类型,这意味着使用它们的原语有一些开销。所以最快的代码(没有仔细的边界检查)是:

def abs(re: Array[Double], im: Array[Double]) = {
  val a = new Array[Double](re.length)
  var i = 0
  while (i < a.length) {
    a(i) = math.sqrt(re(i)*re(i) + im(i)*im(i))
    i += 1
  }
  a
}

或者,尾递归:

def abs(re: Array[Double], im: Array[Double]) = {
  def recurse(a: Array[Double], i: Int = 0): Array[Double] = {
    if (i < a.length) {
      a(i) = math.sqrt(re(i)*re(i) + im(i)*im(i))
      recurse(a, i+1)
    }
    else a
  }
  recurse(new Array[Double](re.length))
}

所以,不幸的是,这段代码最终看起来不太好看;一旦你将它打包在一个方便的复杂数字库中,它就会出现。

如果事实证明您实际上并不需要高效的代码,那么

def abs(re: Array[Double], im: Array[Double]) = {
  (re,im).zipped.map((i,j) => math.sqrt(i*i + j*j))
}

将会清楚地(在您了解zipped的工作原理后)以紧凑和概念的方式完成这项工作。我手中的惩罚是这大约慢了2倍。 (使用List使得它比我手中的while或tail递归慢7倍;带List的{​​{1}}使得它慢20倍;即使没有计算平方根,带有数组的泛型也会慢3倍。)

(编辑:修正时间以反映更典型的用例。)

答案 3 :(得分:1)

编辑后:

好的,我已经运行了我想做的事情。将取两个任何类型的数字列表并返回一个双打数组。

def abs[A](r:List[A], im:List[A])(implicit numeric: Numeric[A]):Array[Double] = {
  var t = new Array[Double](r.length)
  for( i <- r.indices) {          
    t(i) = math.sqrt(numeric.toDouble(r(i))*numeric.toDouble(r(i))+numeric.toDouble(im(i))*numeric.toDouble(im(i)))
  }
  t
}