与Java相比,Scala dot产品非常慢

时间:2016-10-17 12:37:47

标签: java arrays scala math

我是Scala的新手,我想用相同的性能级别翻译我的Java代码。

给定n个浮点矢量和一个附加矢量,我必须计算所有n个点的乘积并获得最大值。

使用Java对我来说非常简单

public static void main(String[] args) {

    int N = 5000000;
    int R = 200;
    float[][] t = new float[N][R];
    float[] u = new float[R];

    Random r = new Random();

    for (int i = 0;i<N;i++) {
        for (int j = 0;j<R;j++) {
            if (i == 0) {
                u[j] = r.nextFloat();
            }
            t[i][j] = r.nextFloat();
        }
    }

    long ts = System.currentTimeMillis();
    float maxScore = -1.0f;

    for (int i = 0;i < N;i++) {
        float score = 0.0f;
        for (int j = 0; i < R;i++) {
            score += u[j] * t[i][j];
        }
        if (score > maxScore) {
            maxScore = score;
        }

    }

    System.out.println(System.currentTimeMillis() - ts);
    System.out.println(maxScore);

}

我的机器上的计算时间为6毫秒。

现在我必须使用Scala

val t = Array.ofDim[Float](N,R)
val u = Array.ofDim[Float](R)

// Filling with random floats like in Java

val ts = System.currentTimeMillis()
var maxScore: Float = -1.0f

for ( i <- 0 until N) {
  var score = 0.0f
  for (j <- 0 until R) {
    score += u(j) * t(i)(j)
  }
  if (score > maxScore) {
    maxScore = score
  }

}

println(System.currentTimeMillis() - ts)
println(maxScore);

上面的代码在我的机器上占用的时间超过秒。 我的想法是Scala没有原始数组结构,例如Java中的float [],并且被集合替换。索引i的访问速度似乎比Java中的原始数组慢。

以下代码甚至更慢:

val maxScore = t.map( r => r zip u map Function.tupled(_*_) reduceLeft (_+_)).max

需要26秒

我应该如何有效地迭代我的2个数组来计算它?

非常感谢

2 个答案:

答案 0 :(得分:22)

嗯,很遗憾地说,但这里奇怪的是你的Java实现有多快,而不是你的Scala速度有多慢 - 遍历100亿(!)单元的6ms听起来好得令人难以置信 - 事实上 - Java实现中有一个拼写错误,这使得此代码的功能更少:

而不是for (int j = 0; j < R;j++),你有for (int j = 0; i < R;i++) - 这使得内部循环只运行200次而不是10亿次 ......

如果你解决了这个问题 - Scala和Java的性能相当。

这是BTW,实际上是Scala的优势 - 让for (j <- 0 until R)错误更难:)

答案 1 :(得分:3)

真正的问题只是一个错字(就像Tzach Zohar提到的那样),但是如果你想提高性能,那么你可以用更多的方式来做:

var i = 0
while (i < N) {
  var j = 0
  var score = 0.0f
  val t1: Array[Float] = t(i)
  while (j < R) {
    score += u(j) * t1(j)
    j += 1
  }
  if (score > maxScore) {
    maxScore = score
  }

  i += 1
}

此代码段的运行速度比for-comprehension版快10-20%。

或者!你可以使用“par”使第一个数组平行并在map中使用while循环:

val maxScore = t.par.map({
  arr =>
    var score = 0.0f
    var j = 0
    while (j < R) {
      score += u(j) * arr(j)
      j += 1
    }
    score
}).max

这个代码在我的机器上运行速度比java版快2-3倍! 亲自尝试一下!祝你好运