我对Rust不太熟悉,而且我正在尝试诊断性能问题。下面是一个相当快的Java代码(在7秒内运行)以及我认为应该是等效的Rust代码。但是,Rust代码运行得非常慢(是的,我也用--release
编译它),它似乎也溢出了。将i32
更改为i64
只是稍后推送溢出,但它仍然会发生。我怀疑我写的内容有一些错误,但在长时间盯着这个问题后,我决定寻求帮助。
public class Blah {
static final int N = 100;
static final int K = 50;
public static void main(String[] args) {
//initialize S
int[] S = new int[N];
for (int n = 1; n <= N; n++) S[n-1] = n*n;
// compute maxsum and minsum
int maxsum = 0;
int minsum = 0;
for (int n = 0; n < K; n++) {
minsum += S[n];
maxsum += S[N-n-1];
}
// initialize x and y
int[][] x = new int[K+1][maxsum+1];
int[][] y = new int[K+1][maxsum+1];
y[0][0] = 1;
// bottom-up DP over n
for (int n = 1; n <= N; n++) {
x[0][0] = 1;
for (int k = 1; k <= K; k++) {
int e = S[n-1];
for (int s = 0; s < e; s++) x[k][s] = y[k][s];
for (int s = 0; s <= maxsum-e; s++) {
x[k][s+e] = y[k-1][s] + y[k][s+e];
}
}
int[][] t = x;
x = y;
y = t;
}
// sum of unique K-subset sums
int sum = 0;
for (int s = minsum; s <= maxsum; s++) {
if (y[K][s] == 1) sum += s;
}
System.out.println(sum);
}
}
extern crate ndarray;
use ndarray::prelude::*;
use std::mem;
fn main() {
let numbers: Vec<i32> = (1..101).map(|x| x * x).collect();
let deg: usize = 50;
let mut min_sum: usize = 0;
for i in 0..deg {
min_sum += numbers[i] as usize;
}
let mut max_sum: usize = 0;
for i in deg..numbers.len() {
max_sum += numbers[i] as usize;
}
// Make an array
let mut x = OwnedArray::from_elem((deg + 1, max_sum + 1), 0i32);
let mut y = OwnedArray::from_elem((deg + 1, max_sum + 1), 0i32);
y[(0, 0)] = 1;
for n in 1..numbers.len() + 1 {
x[(0, 0)] = 1;
println!("Completed step {} out of {}", n, numbers.len());
for k in 1..deg + 1 {
let e = numbers[n - 1] as usize;
for s in 0..e {
x[(k, s)] = y[(k, s)];
}
for s in 0..max_sum - e + 1 {
x[(k, s + e)] = y[(k - 1, s)] + y[(k, s + e)];
}
}
mem::swap(&mut x, &mut y);
}
let mut ans = 0;
for s in min_sum..max_sum + 1 {
if y[(deg, s)] == 1 {
ans += s;
}
}
println!("{}", ans);
}
答案 0 :(得分:2)
一般来说,为了诊断性能问题,我:
最后一步是困难的部分。
在您的情况下,您可以使用单独的实现作为基准。比较两种实现,我们可以看到您的数据结构不同。在Java中,您正在构建嵌套数组,但在Rust中您使用的是ndarray包。我知道crate有一个很好的维护者,但我个人对它的内部知识一无所知,或者它最适合的用例。
所以我使用标准库Vec
重写了它。
我知道的另一件事是直接数组访问不像使用迭代器那样快。这是因为数组访问需要执行边界检查,而迭代器烘焙边界检查自己。很多时候,这意味着使用Iterator
上的方法。
另一个变化是尽可能执行批量数据传输。而不是逐个元素地复制,使用copy_from_slice
等方法移动整个切片。
通过这些更改,代码看起来像这样(为变量名称不好道歉,我确定你可以为它们提出语义名称):
use std::mem;
const N: usize = 100;
const DEGREE: usize = 50;
fn main() {
let numbers: Vec<_> = (1..N+1).map(|v| v*v).collect();
let min_sum = numbers[..DEGREE].iter().fold(0, |a, &v| a + v as usize);
let max_sum = numbers[DEGREE..].iter().fold(0, |a, &v| a + v as usize);
// different data types for x and y!
let mut x = vec![vec![0; max_sum+1]; DEGREE+1];
let mut y = vec![vec![0; max_sum+1]; DEGREE+1];
y[0][0] = 1;
for &e in &numbers {
let e2 = max_sum - e + 1;
let e3 = e + e2;
x[0][0] = 1;
for k in 0..DEGREE {
let current_x = &mut x[k+1];
let prev_y = &y[k];
let current_y = &y[k+1];
// bulk copy
current_x[0..e].copy_from_slice(¤t_y[0..e]);
// more bulk copy
current_x[e..e3].copy_from_slice(&prev_y[0..e2]);
// avoid array index
for (x, y) in current_x[e..e3].iter_mut().zip(¤t_y[e..e3]) {
*x += *y;
}
}
mem::swap(&mut x, &mut y);
}
let sum = y[DEGREE][min_sum..max_sum+1].iter().enumerate().filter(|&(_, &v)| v == 1).fold(0, |a, (i, _)| a + i + min_sum);
println!("{}", sum);
println!("{}", sum == 115039000);
}
在带有2.3 GHz Intel Core i7的OS X 10.11.5上。
我没有足够的Java经验来了解它可以自动执行哪种优化。
我看到的最大潜力下一步是在执行添加时利用SIMD指令;它几乎就是SIMD的制作方式。
作为pointed out by Eli Friedman,通过压缩isn't currently the most performant way来避免数组索引。
通过以下更改,现在时间 1.267s 。
let xx = &mut current_x[e..e3];
xx.copy_from_slice(&prev_y[0..e2]);
let yy = ¤t_y[e..e3];
for i in 0..(e3-e) {
xx[i] += yy[i];
}
这会生成看似展开循环以及使用SIMD指令的程序集:
+0x9b0 movdqu -48(%rsi), %xmm0
+0x9b5 movdqu -48(%rcx), %xmm1
+0x9ba paddd %xmm0, %xmm1
+0x9be movdqu %xmm1, -48(%rsi)
+0x9c3 movdqu -32(%rsi), %xmm0
+0x9c8 movdqu -32(%rcx), %xmm1
+0x9cd paddd %xmm0, %xmm1
+0x9d1 movdqu %xmm1, -32(%rsi)
+0x9d6 movdqu -16(%rsi), %xmm0
+0x9db movdqu -16(%rcx), %xmm1
+0x9e0 paddd %xmm0, %xmm1
+0x9e4 movdqu %xmm1, -16(%rsi)
+0x9e9 movdqu (%rsi), %xmm0
+0x9ed movdqu (%rcx), %xmm1
+0x9f1 paddd %xmm0, %xmm1
+0x9f5 movdqu %xmm1, (%rsi)
+0x9f9 addq $64, %rcx
+0x9fd addq $64, %rsi
+0xa01 addq $-16, %rdx
+0xa05 jne "slow::main+0x9b0"