有没有一种方法可以写入纳尔代布拉矩阵的整个行/列?

时间:2019-06-24 14:52:15

标签: rust

我正在使用DMatrix结构来分配动态大小的矩阵,在该矩阵中,我使用L2范数通过归一化列向量重复覆盖每一列。

// a is some DMatrix of arbitrary size
let col_0 = a.column(0);
let norm_of_col_0 = col_0.normalize();

而不是循环遍历当前列中的每个单元格:

let row = a.shape().0;
let col = a.shape().1;
for col in 0..ncols {
    let norm_of_col = a.column(col).normalize();
    for row in 0..nrows {
        *a.index_mut((row, col)) = norm_of_col()[row];
    }
}

我想用其规范化版本直接覆盖该列。该代码在语义上应如下所示:

*a.index_mut((_, col)) = norm_of_col();

其中(_, col)表示我选择col列,而_表示整行。

更笼统地说,是否有一种方法可以用相同大小和数据类型的新行或列覆盖行或列? 方法insert_columns仅将列添加到现有矩阵中。

如果是这样,这样做的计算速度更快,还是我应该编写一个遍历每个单元格以更新矩阵的辅助方法?

1 个答案:

答案 0 :(得分:1)

您可以使用nalgebra 0.18.0这样操作:

use nalgebra::DMatrix;

fn main() {
    let mut m = DMatrix::from_vec(2, 3, (0 .. 6).map(|n| n as f64).collect());
    dbg!(&m);
    for mut col in m.column_iter_mut() {
        let normalized = col.normalize();
        col.copy_from(&normalized);
    }
    dbg!(&m);
}

与您的代码相比,我还没有衡量该代码的性能。

请注意,copy_from在每个步骤中都没有检查边界,而是仅在循环之前检查一次。我没有检查优化器是否可以在您的代码中执行等效转换。这个简单的基准为我的机器上的答案提供了解决方案的优势(不确定其代表性如何;通常使用基准免责声明):

use criterion::{black_box, criterion_group, criterion_main, Benchmark, Criterion};
use nalgebra::DMatrix;

fn normalize_lib(m: &mut DMatrix<f64>) {
    for mut col in m.column_iter_mut() {
        let normalized = col.normalize();
        col.copy_from(&normalized);
    }
}

fn normalize_hand_rolled(a: &mut DMatrix<f64>) {
    let nrows = a.shape().0;
    let ncols = a.shape().1;
    for col in 0..ncols {
        let norm_of_col = a.column(col).normalize();
        for row in 0..nrows {
            *a.index_mut((row, col)) = norm_of_col[row];
        }
    }
}

fn benchmark(c: &mut Criterion) {
    let mut m0 = DMatrix::new_random(100, 100);
    let mut m1 = m0.clone();
    let bench = Benchmark::new("lib", move |b| b.iter(|| normalize_lib(black_box(&mut m0))))
        .with_function("hand_rolled", move |b| {
            b.iter(|| normalize_hand_rolled(black_box(&mut m1)))
        });
    c.bench("normalize", bench);
}

criterion_group!(benches, benchmark);
criterion_main!(benches);
normalize/lib           time:   [26.102 us 26.245 us 26.443 us]
normalize/hand_rolled   time:   [37.013 us 37.057 us 37.106 us]