我是Scala的新手,我想用相同的性能级别翻译我的Java代码。
给定n个浮点矢量和一个附加矢量,我必须计算所有n个点的乘积并获得最大值。
使用Java对我来说非常简单
public static void main(String[] args) {
int N = 5000000;
int R = 200;
float[][] t = new float[N][R];
float[] u = new float[R];
Random r = new Random();
for (int i = 0;i<N;i++) {
for (int j = 0;j<R;j++) {
if (i == 0) {
u[j] = r.nextFloat();
}
t[i][j] = r.nextFloat();
}
}
long ts = System.currentTimeMillis();
float maxScore = -1.0f;
for (int i = 0;i < N;i++) {
float score = 0.0f;
for (int j = 0; i < R;i++) {
score += u[j] * t[i][j];
}
if (score > maxScore) {
maxScore = score;
}
}
System.out.println(System.currentTimeMillis() - ts);
System.out.println(maxScore);
}
我的机器上的计算时间为6毫秒。
现在我必须使用Scala
val t = Array.ofDim[Float](N,R)
val u = Array.ofDim[Float](R)
// Filling with random floats like in Java
val ts = System.currentTimeMillis()
var maxScore: Float = -1.0f
for ( i <- 0 until N) {
var score = 0.0f
for (j <- 0 until R) {
score += u(j) * t(i)(j)
}
if (score > maxScore) {
maxScore = score
}
}
println(System.currentTimeMillis() - ts)
println(maxScore);
上面的代码在我的机器上占用的时间超过秒。 我的想法是Scala没有原始数组结构,例如Java中的float [],并且被集合替换。索引i的访问速度似乎比Java中的原始数组慢。
以下代码甚至更慢:
val maxScore = t.map( r => r zip u map Function.tupled(_*_) reduceLeft (_+_)).max
需要26秒
我应该如何有效地迭代我的2个数组来计算它?
非常感谢
答案 0 :(得分:22)
嗯,很遗憾地说,但这里奇怪的是你的Java实现有多快,而不是你的Scala速度有多慢 - 遍历100亿(!)单元的6ms听起来好得令人难以置信 - 事实上 - Java实现中有一个拼写错误,这使得此代码的功能更少:
而不是for (int j = 0; j < R;j++)
,你有for (int j = 0; i < R;i++)
- 这使得内部循环只运行200次而不是10亿次 ......
如果你解决了这个问题 - Scala和Java的性能相当。
这是BTW,实际上是Scala的优势 - 让for (j <- 0 until R)
错误更难:)
答案 1 :(得分:3)
真正的问题只是一个错字(就像Tzach Zohar提到的那样),但是如果你想提高性能,那么你可以用更多的方式来做:
var i = 0
while (i < N) {
var j = 0
var score = 0.0f
val t1: Array[Float] = t(i)
while (j < R) {
score += u(j) * t1(j)
j += 1
}
if (score > maxScore) {
maxScore = score
}
i += 1
}
此代码段的运行速度比for-comprehension版快10-20%。
或者!你可以使用“par”使第一个数组平行并在map中使用while循环:
val maxScore = t.par.map({
arr =>
var score = 0.0f
var j = 0
while (j < R) {
score += u(j) * arr(j)
j += 1
}
score
}).max
这个代码在我的机器上运行速度比java版快2-3倍! 亲自尝试一下!祝你好运